mirror of https://github.com/cbeuw/Cloak
Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
bea21b7166 | |
|
|
8317f447d1 | |
|
|
ba7f29d9e6 | |
|
|
10c17c4aca | |
|
|
930e647226 | |
|
|
7be1586973 | |
|
|
a3c3a9b03f | |
|
|
2dd48ef71e |
|
|
@ -1,91 +0,0 @@
|
|||
name: Build and test
|
||||
on: [ push ]
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ ubuntu-latest, macos-latest, windows-latest ]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '^1.24' # The Go version to download (if necessary) and use.
|
||||
- run: go test -race -coverprofile coverage.txt -coverpkg ./... -covermode atomic ./...
|
||||
- uses: codecov/codecov-action@v4
|
||||
with:
|
||||
files: coverage.txt
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
compat-test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
encryption-method: [ plain, chacha20-poly1305 ]
|
||||
num-conn: [ 0, 1, 4 ]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '^1.24'
|
||||
- name: Build Cloak
|
||||
run: make
|
||||
- name: Create configs
|
||||
run: |
|
||||
mkdir config
|
||||
cat << EOF > config/ckclient.json
|
||||
{
|
||||
"Transport": "direct",
|
||||
"ProxyMethod": "iperf",
|
||||
"EncryptionMethod": "${{ matrix.encryption-method }}",
|
||||
"UID": "Q4GAXHVgnDLXsdTpw6bmoQ==",
|
||||
"PublicKey": "4dae/bF43FKGq+QbCc5P/E/MPM5qQeGIArjmJEHiZxc=",
|
||||
"ServerName": "cloudflare.com",
|
||||
"BrowserSig": "firefox",
|
||||
"NumConn": ${{ matrix.num-conn }}
|
||||
}
|
||||
EOF
|
||||
cat << EOF > config/ckserver.json
|
||||
{
|
||||
"ProxyBook": {
|
||||
"iperf": [
|
||||
"tcp",
|
||||
"127.0.0.1:5201"
|
||||
]
|
||||
},
|
||||
"BindAddr": [
|
||||
":8443"
|
||||
],
|
||||
"BypassUID": [
|
||||
"Q4GAXHVgnDLXsdTpw6bmoQ=="
|
||||
],
|
||||
"RedirAddr": "cloudflare.com",
|
||||
"PrivateKey": "AAaskZJRPIAbiuaRLHsvZPvE6gzOeSjg+ZRg1ENau0Y="
|
||||
}
|
||||
EOF
|
||||
- name: Start iperf3 server
|
||||
run: docker run -d --name iperf-server --network host ajoergensen/iperf3:latest --server
|
||||
- name: Test new client against old server
|
||||
run: |
|
||||
docker run -d --name old-cloak-server --network host -v $PWD/config:/go/Cloak/config cbeuw/cloak:latest build/ck-server -c config/ckserver.json --verbosity debug
|
||||
build/ck-client -c config/ckclient.json -s 127.0.0.1 -p 8443 --verbosity debug | tee new-cloak-client.log &
|
||||
docker run --network host ajoergensen/iperf3:latest --client 127.0.0.1 -p 1984
|
||||
docker stop old-cloak-server
|
||||
- name: Test old client against new server
|
||||
run: |
|
||||
build/ck-server -c config/ckserver.json --verbosity debug | tee new-cloak-server.log &
|
||||
docker run -d --name old-cloak-client --network host -v $PWD/config:/go/Cloak/config cbeuw/cloak:latest build/ck-client -c config/ckclient.json -s 127.0.0.1 -p 8443 --verbosity debug
|
||||
docker run --network host ajoergensen/iperf3:latest --client 127.0.0.1 -p 1984
|
||||
docker stop old-cloak-client
|
||||
- name: Dump docker logs
|
||||
if: always()
|
||||
run: |
|
||||
docker container logs iperf-server > iperf-server.log
|
||||
docker container logs old-cloak-server > old-cloak-server.log
|
||||
docker container logs old-cloak-client > old-cloak-client.log
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.encryption-method }}-${{ matrix.num-conn }}-conn-logs
|
||||
path: ./*.log
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
name: Create Release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Build
|
||||
run: |
|
||||
export PATH=${PATH}:`go env GOPATH`/bin
|
||||
v=${GITHUB_REF#refs/*/} ./release.sh
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
files: release/*
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
build-docker:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
cbeuw/cloak
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
corpus/
|
||||
suppressions/
|
||||
crashers/
|
||||
*.zip
|
||||
.idea/
|
||||
build/
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
FROM golang:latest
|
||||
|
||||
RUN git clone https://github.com/cbeuw/Cloak.git
|
||||
WORKDIR Cloak
|
||||
RUN make
|
||||
5
Makefile
5
Makefile
|
|
@ -16,6 +16,11 @@ server:
|
|||
go build -ldflags "-X main.version=${version}" ./cmd/ck-server
|
||||
mv ck-server* ./build
|
||||
|
||||
server_pprof:
|
||||
mkdir -p build
|
||||
go build -ldflags "-X main.version=${version}" -tags pprof ./cmd/ck-server
|
||||
mv ck-server* ./build
|
||||
|
||||
install:
|
||||
mv build/ck-* /usr/local/bin
|
||||
|
||||
|
|
|
|||
245
README.md
245
README.md
|
|
@ -1,239 +1,74 @@
|
|||
[](https://github.com/cbeuw/Cloak/actions)
|
||||
[](https://codecov.io/gh/cbeuw/Cloak)
|
||||
[](https://goreportcard.com/report/github.com/cbeuw/Cloak)
|
||||
[](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url)
|
||||
# Cloak
|
||||
A Shadowsocks plugin that obfuscates the traffic as normal HTTPS traffic to non-blocked websites through domain fronting and disguises the proxy server as a normal webserver.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://user-images.githubusercontent.com/7034308/96387206-3e214100-1198-11eb-8917-689d7c56e0cd.png" />
|
||||
<img src="https://user-images.githubusercontent.com/7034308/155593583-f22bcfe2-ac22-4afb-9288-1a0e8a791a0d.svg" />
|
||||
</p>
|
||||
Cloak multiplexes all traffic through a fixed amount of underlying TCP connections which eliminates the TCP handshake overhead when using vanilla Shadowsocks. Cloak also provides user management, allowing multiple users to connect to the proxy server using **one single port**. It also provides QoS controls for individual users such as upload and download credit limit, as well as bandwidth control.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://user-images.githubusercontent.com/7034308/155629720-54dd8758-ec98-4fed-b603-623f0ad83b6c.svg" />
|
||||
</p>
|
||||
To external observers (such as the GFW), Cloak is completely transparent and behaves like an ordinary HTTPS server. This is done through several [cryptographic mechanisms](https://github.com/cbeuw/Cloak/wiki/Cryptographic-Mechanisms). This eliminates the risk of being detected by traffic analysis and/or active probing.
|
||||
|
||||
Cloak is a [pluggable transport](https://datatracker.ietf.org/meeting/103/materials/slides-103-pearg-pt-slides-01) that enhances
|
||||
traditional proxy tools like OpenVPN to evade [sophisticated censorship](https://en.wikipedia.org/wiki/Deep_packet_inspection) and [data discrimination](https://en.wikipedia.org/wiki/Net_bias).
|
||||
|
||||
Cloak is not a standalone proxy program. Rather, it works by masquerading proxied traffic as normal web browsing
|
||||
activities. In contrast to traditional tools which have very prominent traffic fingerprints and can be blocked by simple filtering rules,
|
||||
it's very difficult to precisely target Cloak with little false positives. This increases the collateral damage to censorship actions as
|
||||
attempts to block Cloak could also damage services the censor state relies on.
|
||||
|
||||
To any third party observer, a host running Cloak server is indistinguishable from an innocent web server. Both while
|
||||
passively observing traffic flow to and from the server, as well as while actively probing the behaviours of a Cloak
|
||||
server. This is achieved through the use a series
|
||||
of [cryptographic steganography techniques](https://github.com/cbeuw/Cloak/wiki/Steganography-and-encryption).
|
||||
|
||||
Cloak can be used in conjunction with any proxy program that tunnels traffic through TCP or
|
||||
UDP, such as Shadowsocks, OpenVPN and Tor. Multiple proxy servers can be running on the same server host and
|
||||
Cloak server will act as a reverse proxy, bridging clients with their desired proxy end.
|
||||
|
||||
Cloak multiplexes traffic through multiple underlying TCP connections which reduces head-of-line blocking and eliminates
|
||||
TCP handshake overhead. This also makes the traffic pattern more similar to real websites.
|
||||
|
||||
Cloak provides multi-user support, allowing multiple clients to connect to the proxy server on the same port (443 by
|
||||
default). It also provides traffic management features such as usage credit and bandwidth control. This allows a proxy
|
||||
server to serve multiple users even if the underlying proxy software wasn't designed for multiple users
|
||||
|
||||
Cloak also supports tunneling through an intermediary CDN server such as Amazon Cloudfront. Such services are so widely used,
|
||||
attempts to disrupt traffic to them can lead to very high collateral damage for the censor.
|
||||
|
||||
## Quick Start
|
||||
|
||||
To quickly deploy Cloak with Shadowsocks on a server, you can run
|
||||
this [script](https://github.com/HirbodBehnam/Shadowsocks-Cloak-Installer/blob/master/Cloak2-Installer.sh) written by
|
||||
@HirbodBehnam
|
||||
|
||||
Table of Contents
|
||||
=================
|
||||
|
||||
* [Quick Start](#quick-start)
|
||||
* [Build](#build)
|
||||
* [Configuration](#configuration)
|
||||
* [Server](#server)
|
||||
* [Client](#client)
|
||||
* [Setup](#setup)
|
||||
* [Server](#server-1)
|
||||
* [To add users](#to-add-users)
|
||||
* [Unrestricted users](#unrestricted-users)
|
||||
* [Users subject to bandwidth and credit controls](#users-subject-to-bandwidth-and-credit-controls)
|
||||
* [Client](#client-1)
|
||||
* [Support me](#support-me)
|
||||
This project is based on a previous project [GoQuiet](https://github.com/cbeuw/GoQuiet). Through multiplexing, Cloak provides a siginifcant reduction in webpage loading time compared to GoQuiet (from 10% to 50+%, depending on the amount of content on the webpage, see [benchmarks](https://github.com/cbeuw/Cloak/wiki/Web-page-loading-benchmarks)).
|
||||
|
||||
## Build
|
||||
|
||||
```bash
|
||||
git clone https://github.com/cbeuw/Cloak
|
||||
cd Cloak
|
||||
go get ./...
|
||||
make
|
||||
```
|
||||
|
||||
Built binaries will be in `build` folder.
|
||||
Simply `make client` and `make server`. Output binary will be in the build folder.
|
||||
Do `make server_pprof` if you want to access the live profiling data.
|
||||
|
||||
## Configuration
|
||||
|
||||
Examples of configuration files can be found under `example_config` folder.
|
||||
|
||||
### Server
|
||||
`WebServerAddr` is the redirection address and port when the incoming traffic is not from shadowsocks. It should correspond to the IP record of the `ServerName` set in `ckclient.json`.
|
||||
|
||||
`RedirAddr` is the redirection address when the incoming traffic is not from a Cloak client. Ideally it should be set to
|
||||
a major website allowed by the censor (e.g. `www.bing.com`)
|
||||
`PrivateKey` is the static curve25519 Diffie-Hellman private key.
|
||||
|
||||
`BindAddr` is a list of addresses Cloak will bind and listen to (e.g. `[":443",":80"]` to listen to port 443 and 80 on
|
||||
all interfaces)
|
||||
`AdminUID` is the UID of the admin user in base64.
|
||||
|
||||
`ProxyBook` is an object whose key is the name of the ProxyMethod used on the client-side (case-sensitive). Its value is
|
||||
an array whose first element is the protocol, and the second element is an `IP:PORT` string of the upstream proxy server
|
||||
that Cloak will forward the traffic to.
|
||||
`DatabasePath` is the path to userinfo.db. If userinfo.db doesn't exist in this directory, Cloak will create one automatically. **If Cloak is started as a Shadowsocks plugin and Shadowsocks is started with its working directory as / (e.g. starting ss-server with systemctl), you need to set this field as an absolute path to a desired folder. If you leave it as default then Cloak will attempt to create userinfo.db under /, which it doesn't have the permission to do so and will raise an error. See Issue #13.**
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"ProxyBook": {
|
||||
"shadowsocks": [
|
||||
"tcp",
|
||||
"localhost:51443"
|
||||
],
|
||||
"openvpn": [
|
||||
"tcp",
|
||||
"localhost:12345"
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64.
|
||||
|
||||
`BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions
|
||||
|
||||
`AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`.
|
||||
|
||||
`DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will
|
||||
create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`.
|
||||
This field also has no effect if `AdminUID` isn't a valid UID or is empty.
|
||||
|
||||
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
|
||||
upstream proxy server. Zero or negative value disables it. Default is 0 (disabled).
|
||||
`BackupDirPath` is the path to save the backups of userinfo.db whenever you delete a user. If left blank, Cloak will attempt to create a folder called db-backup under its working directory. This may not be desired. See notes above.
|
||||
|
||||
### Client
|
||||
|
||||
`UID` is your UID in base64.
|
||||
|
||||
`Transport` can be either `direct` or `CDN`. If the server host wishes you to connect to it directly, use `direct`. If
|
||||
instead a CDN is used, use `CDN`.
|
||||
`PublicKey` is the static curve25519 public key, given by the server admin.
|
||||
|
||||
`PublicKey` is the static curve25519 public key in base64, given by the server admin.
|
||||
`ServerName` is the domain you want to make the GFW think you are visiting.
|
||||
|
||||
`ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the
|
||||
server's `ProxyBook` exactly.
|
||||
`TicketTimeHint` is the time needed for a session ticket to expire and a new one to be generated. Leave it as the default.
|
||||
|
||||
`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-256-gcm` (
|
||||
synonymous to `aes-gcm`), `aes-128-gcm`, and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport
|
||||
security. The point of encryption is to hide fingerprints of proxy protocols and render the payload statistically
|
||||
random-like. **You may only leave it as `plain` if you are certain that your underlying proxy tool already provides BOTH
|
||||
encryption and authentication (via AEAD or similar techniques).**
|
||||
`NumConn` is the amount of underlying TCP connections you want to use.
|
||||
|
||||
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
|
||||
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. Use `random` to randomize the server name for every connection made.
|
||||
|
||||
`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new
|
||||
connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects
|
||||
the alternative domains.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"ServerName": "bing.com",
|
||||
"AlternativeNames": ["cloudflare.com", "github.com"]
|
||||
}
|
||||
```
|
||||
|
||||
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
|
||||
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
|
||||
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with
|
||||
the CDN server, this domain name will be used in the `Host` header of the HTTP request to ask the CDN server to
|
||||
establish a WebSocket connection with this host.
|
||||
|
||||
`CDNWsUrlPath` is the url path used to build websocket request sent under `CDN` mode, and also only has effect
|
||||
when `Transport` is set to `CDN`. If unset, it will default to "/". This option is used to build the first line of the
|
||||
HTTP request after a TLS session is extablished. It's mainly for a Cloak server behind a reverse proxy, while only
|
||||
requests under specific url path are forwarded.
|
||||
|
||||
`NumConn` is the amount of underlying TCP connections you want to use. The default of 4 should be appropriate for most
|
||||
people. Setting it too high will hinder the performance. Setting it to 0 will disable connection multiplexing and each
|
||||
TCP connection will spawn a separate short-lived session that will be closed after it is terminated. This makes it
|
||||
behave like GoQuiet. This maybe useful for people with unstable connections.
|
||||
|
||||
`BrowserSig` is the browser you want to **appear** to be using. It's not relevant to the browser you are actually using.
|
||||
Currently, `chrome`, `firefox` and `safari` are supported.
|
||||
|
||||
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
|
||||
Cloak server. Zero or negative value disables it. Default is 0 (disabled). Warning: Enabling it might make your server
|
||||
more detectable as a proxy, but it will make the Cloak client detect internet interruption more quickly.
|
||||
|
||||
`StreamTimeout` is the number of seconds of Cloak waits for an incoming connection from a proxy program to send any
|
||||
data, after which the connection will be closed by Cloak. Cloak will not enforce any timeout on TCP connections after it
|
||||
is established.
|
||||
`MaskBrowser` is the browser you want to **make the GFW _think_ you are using, it has NOTHING to do with the web browser or any web application you are using on your machine**. Currently, `chrome` and `firefox` are supported.
|
||||
|
||||
## Setup
|
||||
### For the administrator of the server
|
||||
**Run this script: https://github.com/HirbodBehnam/Shadowsocks-Cloak-Installer/blob/master/Shadowsocks-Cloak-Installer.sh (thanks to [@HirbodBehnam](https://github.com/HirbodBehnam))** or do it manually:
|
||||
|
||||
### Server
|
||||
0. [Install and configure shadowsocks-libev on your server](https://github.com/shadowsocks/shadowsocks-libev#installation)
|
||||
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo. If you wish to build it, make sure you fetch the dependencies using `go get github.com/boltdb/bolt`, `go get github.com/juju/ratelimit` and `go get golang.org/x/crypto/curve25519`
|
||||
2. Run ck-server -k. The base64 string before the comma is the **public** key to be given to users, the one after the comma is the **private** key to be kept secret
|
||||
3. Run `ck-server -u`. This will be used as the AdminUID
|
||||
4. Put the private key and the AdminUID you obtained previously into config/ckserver.json
|
||||
5. Edit the configuration file of shadowsocks-libev (default location is /etc/shadowsocks-libev/config.json). Let `server_port` be `443`, `plugin` be the full path to the ck-server binary and `plugin_opts` be the full path to ckserver.json. If the fields `plugin` and `plugin_opts` were not present originally, add these fields to the config file.
|
||||
6. Run ss-server as root (because we are binding to TCP port 443)
|
||||
|
||||
0. Install at least one underlying proxy server (e.g. OpenVPN, Shadowsocks).
|
||||
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo.
|
||||
2. Run `ck-server -key`. The **public** should be given to users, the **private** key should be kept secret.
|
||||
3. (Skip if you only want to add unrestricted users) Run `ck-server -uid`. The new UID will be used as `AdminUID`.
|
||||
4. Copy example_config/ckserver.json into a desired location. Change `PrivateKey` to the private key you just obtained;
|
||||
change `AdminUID` to the UID you just obtained.
|
||||
5. Configure your underlying proxy server so that they all listen on localhost. Edit `ProxyBook` in the configuration
|
||||
file accordingly
|
||||
6. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides)
|
||||
Run `sudo ck-server -c <path to ckserver.json>`. ck-server needs root privilege because it binds to a low numbered
|
||||
port (443). Alternatively you can follow https://superuser.com/a/892391 to avoid granting ck-server root privilege
|
||||
unnecessarily.
|
||||
#### If you want to add more users
|
||||
1. Run ck-server -u to generate a new UID
|
||||
2. On your client, run `ck-client -a -c <path-to-ckclient.json>` to enter admin mode
|
||||
3. Input as prompted, that is your ip:port of the server and your AdminUID. Enter 4 to create a new user.
|
||||
4. Enter the the newly generated UID, enter SessionsCap (maximum amount of concurrent sessions a user can have), UpRate and DownRate (in bytes/s), UpCredit and DownCredit (in bytes) and ExpiryTime (as a unix epoch)
|
||||
5. Give your **public** key and the newly generated UID to the new user
|
||||
|
||||
#### To add users
|
||||
|
||||
##### Unrestricted users
|
||||
|
||||
Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.json`
|
||||
|
||||
##### Users subject to bandwidth and credit controls
|
||||
|
||||
0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db`
|
||||
in `DatabasePath` (Cloak will create this file for you if it didn't already exist).
|
||||
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to
|
||||
enter admin mode
|
||||
2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data
|
||||
entered into this site are processed between your browser and the Cloak API endpoint you specified. Alternatively you
|
||||
can download the repo at https://github.com/cbeuw/Cloak-panel and open `index.html` in a browser. No web server is
|
||||
required).
|
||||
3. Type in `127.0.0.1:<the port you entered in step 1>` as the API Base, and click `List`.
|
||||
4. You can add in more users by clicking the `+` panel
|
||||
|
||||
Note: the user database is persistent as it's in-disk. You don't need to add the users again each time you start
|
||||
ck-server.
|
||||
|
||||
### Client
|
||||
Note: the user database is persistent as it's in-disk. You don't need to add the users again each time you start ck-server.
|
||||
|
||||
### Instructions for clients
|
||||
**Android client is available here: https://github.com/cbeuw/Cloak-android**
|
||||
|
||||
0. Install the underlying proxy client corresponding to what the server has.
|
||||
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo.
|
||||
2. Obtain the public key and your UID from the administrator of your server
|
||||
3. Copy `example_config/ckclient.json` into a location of your choice. Enter the `UID` and `PublicKey` you have
|
||||
obtained. Set `ProxyMethod` to match exactly the corresponding entry in `ProxyBook` on the server end
|
||||
4. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides)
|
||||
Run `ck-client -c <path to ckclient.json> -s <ip of your server>`
|
||||
0. Install and configure a version of shadowsocks client that supports plugins (such as shadowsocks-libev and shadowsocks-windows)
|
||||
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo. If you wish to build it, make sure you fetch the dependencies using `go get github.com/boltdb/bolt`, `go get github.com/juju/ratelimit` and `go get golang.org/x/crypto/curve25519`
|
||||
2. Obtain the public key and your UID (or the AdminUID, if you are the server admin) from the administrator of your server
|
||||
3. Put the public key and the UID you obtained into config/ckclient.json
|
||||
4. Configure your shadowsocks client with your server information. The field `plugin` should be the path to ck-server binary and `plugin_opts` should be the path to ckclient.json
|
||||
|
||||
## Support me
|
||||
|
||||
If you find this project useful, you can visit my [merch store](https://www.redbubble.com/people/cbeuw/explore);
|
||||
alternatively you can donate directly to me
|
||||
If you find this project useful, donations are greatly appreciated!
|
||||
|
||||
[](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,276 @@
|
|||
// +build !android
|
||||
|
||||
package main
|
||||
|
||||
// TODO: rewrite this. Think of another way of admin control
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/client"
|
||||
"github.com/cbeuw/Cloak/internal/client/TLS"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
UID []byte
|
||||
// ALL of the following fields have to be accessed atomically
|
||||
SessionsCap uint32
|
||||
UpRate int64
|
||||
DownRate int64
|
||||
UpCredit int64
|
||||
DownCredit int64
|
||||
ExpiryTime int64
|
||||
}
|
||||
|
||||
type administrator struct {
|
||||
adminConn net.Conn
|
||||
adminUID []byte
|
||||
}
|
||||
|
||||
func adminPrompt(sta *client.State) error {
|
||||
a, err := adminHandshake(sta)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
fmt.Println(`1 listActiveUsers none []uids
|
||||
2 listAllUsers none []userinfo
|
||||
3 getUserInfo uid userinfo
|
||||
4 addNewUser userinfo ok
|
||||
5 delUser uid ok
|
||||
6 syncMemFromDB uid ok
|
||||
|
||||
7 setSessionsCap uid cap ok
|
||||
8 setUpRate uid rate ok
|
||||
9 setDownRate uid rate ok
|
||||
10 setUpCredit uid credit ok
|
||||
11 setDownCredit uid credit ok
|
||||
12 setExpiryTime uid time ok
|
||||
13 addUpCredit uid delta ok
|
||||
14 addDownCredit uid delta ok`)
|
||||
buf := make([]byte, 16000)
|
||||
for {
|
||||
req, err := a.getRequest()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
a.adminConn.Write(req)
|
||||
n, err := a.adminConn.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := a.checkAndDecrypt(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println(string(resp))
|
||||
}
|
||||
}
|
||||
|
||||
func adminHandshake(sta *client.State) (*administrator, error) {
|
||||
fmt.Println("Enter the ip:port of your server")
|
||||
var addr string
|
||||
fmt.Scanln(&addr)
|
||||
fmt.Println("Enter the admin UID")
|
||||
var b64AdminUID string
|
||||
fmt.Scanln(&b64AdminUID)
|
||||
adminUID, err := base64.StdEncoding.DecodeString(b64AdminUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sta.UID = adminUID
|
||||
|
||||
remoteConn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientHello := TLS.ComposeInitHandshake(sta)
|
||||
_, err = remoteConn.Write(clientHello)
|
||||
|
||||
// Three discarded messages: ServerHello, ChangeCipherSpec and Finished
|
||||
discardBuf := make([]byte, 1024)
|
||||
for c := 0; c < 3; c++ {
|
||||
_, err = util.ReadTLS(remoteConn, discardBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
reply := TLS.ComposeReply()
|
||||
_, err = remoteConn.Write(reply)
|
||||
a := &administrator{remoteConn, adminUID}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *administrator) getRequest() (req []byte, err error) {
|
||||
promptUID := func() []byte {
|
||||
fmt.Println("Enter UID")
|
||||
var b64UID string
|
||||
fmt.Scanln(&b64UID)
|
||||
ret, _ := base64.StdEncoding.DecodeString(b64UID)
|
||||
return ret
|
||||
}
|
||||
|
||||
promptInt64 := func(name string) []byte {
|
||||
fmt.Println("Enter New " + name)
|
||||
var val int64
|
||||
fmt.Scanln(&val)
|
||||
ret := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(ret, uint64(val))
|
||||
return ret
|
||||
}
|
||||
promptUint32 := func(name string) []byte {
|
||||
fmt.Println("Enter New " + name)
|
||||
var val uint32
|
||||
fmt.Scanln(&val)
|
||||
ret := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(ret, val)
|
||||
return ret
|
||||
}
|
||||
|
||||
fmt.Println("Select your command")
|
||||
var cmd string
|
||||
fmt.Scanln(&cmd)
|
||||
switch cmd {
|
||||
case "1":
|
||||
req = a.request([]byte{0x01})
|
||||
case "2":
|
||||
req = a.request([]byte{0x02})
|
||||
case "3":
|
||||
UID := promptUID()
|
||||
req = a.request(append([]byte{0x03}, UID...))
|
||||
case "4":
|
||||
var uinfo UserInfo
|
||||
var b64UID string
|
||||
fmt.Printf("UID:")
|
||||
fmt.Scanln(&b64UID)
|
||||
UID, _ := base64.StdEncoding.DecodeString(b64UID)
|
||||
uinfo.UID = UID
|
||||
fmt.Printf("SessionsCap:")
|
||||
fmt.Scanf("%d", &uinfo.SessionsCap)
|
||||
fmt.Printf("UpRate:")
|
||||
fmt.Scanf("%d", &uinfo.UpRate)
|
||||
fmt.Printf("DownRate:")
|
||||
fmt.Scanf("%d", &uinfo.DownRate)
|
||||
fmt.Printf("UpCredit:")
|
||||
fmt.Scanf("%d", &uinfo.UpCredit)
|
||||
fmt.Printf("DownCredit:")
|
||||
fmt.Scanf("%d", &uinfo.DownCredit)
|
||||
fmt.Printf("ExpiryTime:")
|
||||
fmt.Scanf("%d", &uinfo.ExpiryTime)
|
||||
marshed, _ := json.Marshal(uinfo)
|
||||
req = a.request(append([]byte{0x04}, marshed...))
|
||||
case "5":
|
||||
UID := promptUID()
|
||||
fmt.Println("Are you sure to delete this user? y/n")
|
||||
var ans string
|
||||
fmt.Scanln(&ans)
|
||||
if ans != "y" && ans != "Y" {
|
||||
return
|
||||
}
|
||||
req = a.request(append([]byte{0x05}, UID...))
|
||||
case "6":
|
||||
UID := promptUID()
|
||||
req = a.request(append([]byte{0x06}, UID...))
|
||||
case "7":
|
||||
arg := make([]byte, 36)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptUint32("SessionsCap"))
|
||||
req = a.request(append([]byte{0x07}, arg...))
|
||||
case "8":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("UpRate"))
|
||||
req = a.request(append([]byte{0x08}, arg...))
|
||||
case "9":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("DownRate"))
|
||||
req = a.request(append([]byte{0x09}, arg...))
|
||||
case "10":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("UpCredit"))
|
||||
req = a.request(append([]byte{0x0a}, arg...))
|
||||
case "11":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("DownCredit"))
|
||||
req = a.request(append([]byte{0x0b}, arg...))
|
||||
case "12":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("ExpiryTime"))
|
||||
req = a.request(append([]byte{0x0c}, arg...))
|
||||
case "13":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("UpCredit to add"))
|
||||
req = a.request(append([]byte{0x0d}, arg...))
|
||||
case "14":
|
||||
arg := make([]byte, 40)
|
||||
copy(arg, promptUID())
|
||||
copy(arg[32:], promptInt64("DownCredit to add"))
|
||||
req = a.request(append([]byte{0x0e}, arg...))
|
||||
default:
|
||||
return nil, errors.New("Unreconised cmd")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// protocol: 0[TLS record layer 5 bytes]5[IV 16 bytes]21[data][hmac 32 bytes]
|
||||
func (a *administrator) request(data []byte) []byte {
|
||||
dataLen := len(data)
|
||||
|
||||
buf := make([]byte, 5+16+dataLen+32)
|
||||
buf[0] = 0x17
|
||||
buf[1] = 0x03
|
||||
buf[2] = 0x03
|
||||
binary.BigEndian.PutUint16(buf[3:5], uint16(16+dataLen+32))
|
||||
|
||||
rand.Read(buf[5:21]) //iv
|
||||
copy(buf[21:], data)
|
||||
block, _ := aes.NewCipher(a.adminUID[0:16])
|
||||
stream := cipher.NewCTR(block, buf[5:21])
|
||||
stream.XORKeyStream(buf[21:21+dataLen], buf[21:21+dataLen])
|
||||
|
||||
mac := hmac.New(sha256.New, a.adminUID[16:32])
|
||||
mac.Write(buf[5 : 21+dataLen])
|
||||
copy(buf[21+dataLen:], mac.Sum(nil))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
var ErrInvalidMac = errors.New("Mac mismatch")
|
||||
|
||||
func (a *administrator) checkAndDecrypt(data []byte) ([]byte, error) {
|
||||
macIndex := len(data) - 32
|
||||
mac := hmac.New(sha256.New, a.adminUID[16:32])
|
||||
mac.Write(data[5:macIndex])
|
||||
expected := mac.Sum(nil)
|
||||
if !hmac.Equal(data[macIndex:], expected) {
|
||||
return nil, ErrInvalidMac
|
||||
}
|
||||
|
||||
iv := data[5:21]
|
||||
ret := data[21:macIndex]
|
||||
block, _ := aes.NewCipher(a.adminUID[0:16])
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(ret, ret)
|
||||
return ret, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
// +build android
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/cbeuw/Cloak/internal/client"
|
||||
|
||||
func adminPrompt(sta *client.State) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,69 +1,160 @@
|
|||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/client"
|
||||
"github.com/cbeuw/Cloak/internal/client/TLS"
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
var version string
|
||||
|
||||
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
|
||||
// The maximum size of TLS message will be 16396+12. 12 because of the stream header
|
||||
// 16408 is the max TLS message size on Firefox
|
||||
buf := make([]byte, 16396)
|
||||
for {
|
||||
i, err := io.ReadAtLeast(src, buf, 1)
|
||||
if err != nil {
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
}
|
||||
i, err = dst.Write(buf[:i])
|
||||
if err != nil {
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This establishes a connection with ckserver and performs a handshake
|
||||
func makeRemoteConn(sta *client.State) (net.Conn, error) {
|
||||
|
||||
// For android
|
||||
d := net.Dialer{Control: protector}
|
||||
|
||||
clientHello := TLS.ComposeInitHandshake(sta)
|
||||
connectingIP := sta.SS_REMOTE_HOST
|
||||
if net.ParseIP(connectingIP).To4() == nil {
|
||||
// IPv6 needs square brackets
|
||||
connectingIP = "[" + connectingIP + "]"
|
||||
}
|
||||
remoteConn, err := d.Dial("tcp", connectingIP+":"+sta.SS_REMOTE_PORT)
|
||||
if err != nil {
|
||||
log.Printf("Connecting to remote: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
_, err = remoteConn.Write(clientHello)
|
||||
if err != nil {
|
||||
log.Printf("Sending ClientHello: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Three discarded messages: ServerHello, ChangeCipherSpec and Finished
|
||||
discardBuf := make([]byte, 1024)
|
||||
for c := 0; c < 3; c++ {
|
||||
_, err = util.ReadTLS(remoteConn, discardBuf)
|
||||
if err != nil {
|
||||
log.Printf("Reading discarded message %v: %v\n", c, err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
reply := TLS.ComposeReply()
|
||||
_, err = remoteConn.Write(reply)
|
||||
if err != nil {
|
||||
log.Printf("Sending reply to remote: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return remoteConn, nil
|
||||
|
||||
}
|
||||
|
||||
func makeSession(sta *client.State) *mux.Session {
|
||||
log.Println("Attemtping to start a new session")
|
||||
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
|
||||
// sessionID is limited to its UID.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
sessionID := rand.Uint32()
|
||||
sta.SetSessionID(sessionID)
|
||||
var UNLIMITED_DOWN int64 = 1e15
|
||||
var UNLIMITED_UP int64 = 1e15
|
||||
valve := mux.MakeValve(1e12, 1e12, &UNLIMITED_DOWN, &UNLIMITED_UP)
|
||||
obfs := mux.MakeObfs(sta.UID)
|
||||
deobfs := mux.MakeDeobfs(sta.UID)
|
||||
sesh := mux.MakeSession(sessionID, valve, obfs, deobfs, util.ReadTLS)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < sta.NumConn; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
makeconn:
|
||||
conn, err := makeRemoteConn(sta)
|
||||
if err != nil {
|
||||
log.Printf("Failed to establish new connections to remote: %v\n", err)
|
||||
time.Sleep(time.Second * 3)
|
||||
goto makeconn
|
||||
}
|
||||
sesh.AddConnection(conn)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
log.Printf("Session %v established", sessionID)
|
||||
return sesh
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Should be 127.0.0.1 to listen to a proxy client on this machine
|
||||
// Should be 127.0.0.1 to listen to ss-local on this machine
|
||||
var localHost string
|
||||
// port used by proxy clients to communicate with cloak client
|
||||
// server_port in ss config, ss sends data on loopback using this port
|
||||
var localPort string
|
||||
// The ip of the proxy server
|
||||
var remoteHost string
|
||||
// The proxy port,should be 443
|
||||
var remotePort string
|
||||
var proxyMethod string
|
||||
var udp bool
|
||||
var config string
|
||||
var b64AdminUID string
|
||||
var vpnMode bool
|
||||
var tcpFastOpen bool
|
||||
var pluginOpts string
|
||||
isAdmin := new(bool)
|
||||
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
log_init()
|
||||
|
||||
ssPluginMode := os.Getenv("SS_LOCAL_HOST") != ""
|
||||
|
||||
verbosity := flag.String("verbosity", "info", "verbosity level")
|
||||
if ssPluginMode {
|
||||
config = os.Getenv("SS_PLUGIN_OPTIONS")
|
||||
flag.BoolVar(&vpnMode, "V", false, "ignored.")
|
||||
flag.BoolVar(&tcpFastOpen, "fast-open", false, "ignored.")
|
||||
flag.Parse() // for verbosity only
|
||||
if os.Getenv("SS_LOCAL_HOST") != "" {
|
||||
localHost = os.Getenv("SS_LOCAL_HOST")
|
||||
localPort = os.Getenv("SS_LOCAL_PORT")
|
||||
remoteHost = os.Getenv("SS_REMOTE_HOST")
|
||||
remotePort = os.Getenv("SS_REMOTE_PORT")
|
||||
pluginOpts = os.Getenv("SS_PLUGIN_OPTIONS")
|
||||
} else {
|
||||
flag.StringVar(&localHost, "i", "127.0.0.1", "localHost: Cloak listens to proxy clients on this ip")
|
||||
flag.StringVar(&localPort, "l", "1984", "localPort: Cloak listens to proxy clients on this port")
|
||||
localHost = "127.0.0.1"
|
||||
flag.StringVar(&localPort, "l", "", "localPort: same as server_port in ss config, the plugin listens to SS using this")
|
||||
flag.StringVar(&remoteHost, "s", "", "remoteHost: IP of your proxy server")
|
||||
flag.StringVar(&remotePort, "p", "443", "remotePort: proxy port, should be 443")
|
||||
flag.BoolVar(&udp, "u", false, "udp: set this flag if the underlying proxy is using UDP protocol")
|
||||
flag.StringVar(&config, "c", "ckclient.json", "config: path to the configuration file or options separated with semicolons")
|
||||
flag.StringVar(&proxyMethod, "proxy", "", "proxy: the proxy method's name. It must match exactly with the corresponding entry in server's ProxyBook")
|
||||
flag.StringVar(&b64AdminUID, "a", "", "adminUID: enter the adminUID to serve the admin api")
|
||||
flag.StringVar(&pluginOpts, "c", "ckclient.json", "pluginOpts: path to ckclient.json or options seperated with semicolons")
|
||||
askVersion := flag.Bool("v", false, "Print the version number")
|
||||
isAdmin = flag.Bool("a", false, "Admin mode")
|
||||
printUsage := flag.Bool("h", false, "Print this message")
|
||||
|
||||
// commandline arguments overrides json
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *askVersion {
|
||||
fmt.Printf("ck-client %s", version)
|
||||
fmt.Printf("ck-client %s\n", version)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -72,135 +163,83 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
log.Info("Starting standalone mode")
|
||||
log.Println("Starting standalone mode")
|
||||
}
|
||||
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
})
|
||||
lvl, err := log.ParseLevel(*verbosity)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
if *isAdmin {
|
||||
sta := client.InitState("", "", "", "", time.Now)
|
||||
err := sta.ParseConfig(pluginOpts)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = adminPrompt(sta)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
log.SetLevel(lvl)
|
||||
|
||||
rawConfig, err := client.ParseConfig(config)
|
||||
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now)
|
||||
err := sta.ParseConfig(pluginOpts)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if ssPluginMode {
|
||||
if rawConfig.ProxyMethod == "" {
|
||||
rawConfig.ProxyMethod = "shadowsocks"
|
||||
if sta.SS_LOCAL_PORT == "" {
|
||||
log.Fatal("Must specify localPort")
|
||||
}
|
||||
if sta.SS_REMOTE_HOST == "" {
|
||||
log.Fatal("Must specify remoteHost")
|
||||
}
|
||||
if sta.TicketTimeHint == 0 {
|
||||
log.Fatal("TicketTimeHint cannot be empty or 0")
|
||||
}
|
||||
listeningIP := sta.SS_LOCAL_HOST
|
||||
if net.ParseIP(listeningIP).To4() == nil {
|
||||
// IPv6 needs square brackets
|
||||
listeningIP = "[" + listeningIP + "]"
|
||||
}
|
||||
listener, err := net.Listen("tcp", listeningIP+":"+sta.SS_LOCAL_PORT)
|
||||
log.Println("Listening for ss on " + listeningIP + ":" + sta.SS_LOCAL_PORT)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var sesh *mux.Session
|
||||
|
||||
for {
|
||||
ssConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
// json takes precedence over environment variables
|
||||
// i.e. if json field isn't empty, use that
|
||||
if rawConfig.RemoteHost == "" {
|
||||
rawConfig.RemoteHost = os.Getenv("SS_REMOTE_HOST")
|
||||
if sesh == nil || sesh.IsBroken() {
|
||||
sesh = makeSession(sta)
|
||||
}
|
||||
if rawConfig.RemotePort == "" {
|
||||
rawConfig.RemotePort = os.Getenv("SS_REMOTE_PORT")
|
||||
}
|
||||
if rawConfig.LocalHost == "" {
|
||||
rawConfig.LocalHost = os.Getenv("SS_LOCAL_HOST")
|
||||
}
|
||||
if rawConfig.LocalPort == "" {
|
||||
rawConfig.LocalPort = os.Getenv("SS_LOCAL_PORT")
|
||||
}
|
||||
} else {
|
||||
// commandline argument takes precedence over json
|
||||
// if commandline argument is set, use commandline
|
||||
flag.Visit(func(f *flag.Flag) {
|
||||
// manually set ones
|
||||
switch f.Name {
|
||||
case "i":
|
||||
rawConfig.LocalHost = localHost
|
||||
case "l":
|
||||
rawConfig.LocalPort = localPort
|
||||
case "s":
|
||||
rawConfig.RemoteHost = remoteHost
|
||||
case "p":
|
||||
rawConfig.RemotePort = remotePort
|
||||
case "u":
|
||||
rawConfig.UDP = udp
|
||||
case "proxy":
|
||||
rawConfig.ProxyMethod = proxyMethod
|
||||
go func() {
|
||||
data := make([]byte, 10240)
|
||||
i, err := io.ReadAtLeast(ssConn, data, 1)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
ssConn.Close()
|
||||
return
|
||||
}
|
||||
})
|
||||
// ones with default values
|
||||
if rawConfig.LocalHost == "" {
|
||||
rawConfig.LocalHost = localHost
|
||||
}
|
||||
if rawConfig.LocalPort == "" {
|
||||
rawConfig.LocalPort = localPort
|
||||
}
|
||||
if rawConfig.RemotePort == "" {
|
||||
rawConfig.RemotePort = remotePort
|
||||
}
|
||||
stream, err := sesh.OpenStream()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
ssConn.Close()
|
||||
return
|
||||
}
|
||||
_, err = stream.Write(data[:i])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
ssConn.Close()
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
go pipe(ssConn, stream)
|
||||
pipe(stream, ssConn)
|
||||
}()
|
||||
}
|
||||
|
||||
localConfig, remoteConfig, authInfo, err := rawConfig.ProcessRawConfig(common.RealWorldState)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var adminUID []byte
|
||||
if b64AdminUID != "" {
|
||||
adminUID, err = base64.StdEncoding.DecodeString(b64AdminUID)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var seshMaker func() *mux.Session
|
||||
|
||||
d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive}
|
||||
|
||||
if adminUID != nil {
|
||||
log.Infof("API base is %v", localConfig.LocalAddr)
|
||||
authInfo.UID = adminUID
|
||||
authInfo.SessionId = 0
|
||||
remoteConfig.NumConn = 1
|
||||
|
||||
seshMaker = func() *mux.Session {
|
||||
return client.MakeSession(remoteConfig, authInfo, d)
|
||||
}
|
||||
} else {
|
||||
var network string
|
||||
if authInfo.Unordered {
|
||||
network = "UDP"
|
||||
} else {
|
||||
network = "TCP"
|
||||
}
|
||||
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
|
||||
seshMaker = func() *mux.Session {
|
||||
authInfo := authInfo // copy the struct because we are overwriting SessionId
|
||||
|
||||
randByte := make([]byte, 1)
|
||||
common.RandRead(authInfo.WorldState.Rand, randByte)
|
||||
authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)]
|
||||
|
||||
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
|
||||
// sessionID is limited to its UID.
|
||||
quad := make([]byte, 4)
|
||||
common.RandRead(authInfo.WorldState.Rand, quad)
|
||||
authInfo.SessionId = binary.BigEndian.Uint32(quad)
|
||||
return client.MakeSession(remoteConfig, authInfo, d)
|
||||
}
|
||||
}
|
||||
|
||||
if authInfo.Unordered {
|
||||
acceptor := func() (*net.UDPConn, error) {
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", localConfig.LocalAddr)
|
||||
return net.ListenUDP("udp", udpAddr)
|
||||
}
|
||||
|
||||
client.RouteUDP(acceptor, localConfig.Timeout, remoteConfig.Singleplex, seshMaker)
|
||||
} else {
|
||||
listener, err := net.Listen("tcp", localConfig.LocalAddr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, seshMaker)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
//go:build !android
|
||||
// +build !android
|
||||
|
||||
package main
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build android
|
||||
// +build android
|
||||
|
||||
package main
|
||||
|
|
@ -29,10 +28,9 @@ import "C"
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"log"
|
||||
"os"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -68,6 +66,8 @@ func lineLog(f *os.File, priority C.int) {
|
|||
|
||||
func log_init() {
|
||||
log.SetOutput(infoWriter{})
|
||||
// android logcat includes all of log.LstdFlags
|
||||
log.SetFlags(log.Flags() &^ log.LstdFlags)
|
||||
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
//go:build !android
|
||||
// +build !android
|
||||
|
||||
package main
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
//go:build android
|
||||
// +build android
|
||||
|
||||
package main
|
||||
|
||||
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go
|
||||
|
|
@ -15,60 +13,62 @@ package main
|
|||
#include <sys/un.h>
|
||||
#include <sys/uio.h>
|
||||
|
||||
#define ANCIL_FD_BUFFER(n) \
|
||||
struct { \
|
||||
struct cmsghdr h; \
|
||||
int fd[n]; \
|
||||
}
|
||||
#define ANCIL_FD_BUFFER(n) \
|
||||
struct { \
|
||||
struct cmsghdr h; \
|
||||
int fd[n]; \
|
||||
}
|
||||
|
||||
int ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds,
|
||||
void *buffer) {
|
||||
struct msghdr msghdr;
|
||||
char nothing = '!';
|
||||
struct iovec nothing_ptr;
|
||||
struct cmsghdr *cmsg;
|
||||
int i;
|
||||
int
|
||||
ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds, void *buffer)
|
||||
{
|
||||
struct msghdr msghdr;
|
||||
char nothing = '!';
|
||||
struct iovec nothing_ptr;
|
||||
struct cmsghdr *cmsg;
|
||||
int i;
|
||||
|
||||
nothing_ptr.iov_base = ¬hing;
|
||||
nothing_ptr.iov_len = 1;
|
||||
msghdr.msg_name = NULL;
|
||||
msghdr.msg_namelen = 0;
|
||||
msghdr.msg_iov = ¬hing_ptr;
|
||||
msghdr.msg_iovlen = 1;
|
||||
msghdr.msg_flags = 0;
|
||||
msghdr.msg_control = buffer;
|
||||
msghdr.msg_controllen = sizeof(struct cmsghdr) + sizeof(int) * n_fds;
|
||||
cmsg = CMSG_FIRSTHDR(&msghdr);
|
||||
cmsg->cmsg_len = msghdr.msg_controllen;
|
||||
cmsg->cmsg_level = SOL_SOCKET;
|
||||
cmsg->cmsg_type = SCM_RIGHTS;
|
||||
for (i = 0; i < n_fds; i++)
|
||||
((int *)CMSG_DATA(cmsg))[i] = fds[i];
|
||||
return (sendmsg(sock, &msghdr, 0) >= 0 ? 0 : -1);
|
||||
}
|
||||
nothing_ptr.iov_base = ¬hing;
|
||||
nothing_ptr.iov_len = 1;
|
||||
msghdr.msg_name = NULL;
|
||||
msghdr.msg_namelen = 0;
|
||||
msghdr.msg_iov = ¬hing_ptr;
|
||||
msghdr.msg_iovlen = 1;
|
||||
msghdr.msg_flags = 0;
|
||||
msghdr.msg_control = buffer;
|
||||
msghdr.msg_controllen = sizeof(struct cmsghdr) + sizeof(int) * n_fds;
|
||||
cmsg = CMSG_FIRSTHDR(&msghdr);
|
||||
cmsg->cmsg_len = msghdr.msg_controllen;
|
||||
cmsg->cmsg_level = SOL_SOCKET;
|
||||
cmsg->cmsg_type = SCM_RIGHTS;
|
||||
for(i = 0; i < n_fds; i++)
|
||||
((int *)CMSG_DATA(cmsg))[i] = fds[i];
|
||||
return(sendmsg(sock, &msghdr, 0) >= 0 ? 0 : -1);
|
||||
}
|
||||
|
||||
int ancil_send_fd(int sock, int fd) {
|
||||
ANCIL_FD_BUFFER(1) buffer;
|
||||
int
|
||||
ancil_send_fd(int sock, int fd)
|
||||
{
|
||||
ANCIL_FD_BUFFER(1) buffer;
|
||||
|
||||
return (ancil_send_fds_with_buffer(sock, &fd, 1, &buffer));
|
||||
}
|
||||
return(ancil_send_fds_with_buffer(sock, &fd, 1, &buffer));
|
||||
}
|
||||
|
||||
void set_timeout(int sock) {
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 3;
|
||||
tv.tv_usec = 0;
|
||||
setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv,
|
||||
sizeof(struct timeval));
|
||||
setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv,
|
||||
sizeof(struct timeval));
|
||||
}
|
||||
void
|
||||
set_timeout(int sock)
|
||||
{
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 3;
|
||||
tv.tv_usec = 0;
|
||||
setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval));
|
||||
setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, sizeof(struct timeval));
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// In Android, once an app starts the VpnService, all outgoing traffic are routed by the system
|
||||
|
|
|
|||
|
|
@ -1,199 +1,294 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"github.com/cbeuw/Cloak/internal/server"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
var version string
|
||||
|
||||
func resolveBindAddr(bindAddrs []string) ([]net.Addr, error) {
|
||||
var addrs []net.Addr
|
||||
for _, addr := range bindAddrs {
|
||||
bindAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
|
||||
// The maximum size of TLS message will be 16396+12. 12 because of the stream header
|
||||
// 16408 is the max TLS message size on Firefox
|
||||
buf := make([]byte, 16396)
|
||||
for {
|
||||
i, err := io.ReadAtLeast(src, buf, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
}
|
||||
i, err = dst.Write(buf[:i])
|
||||
if err != nil {
|
||||
go dst.Close()
|
||||
go src.Close()
|
||||
return
|
||||
}
|
||||
addrs = append(addrs, bindAddr)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// parse what shadowsocks server wants us to bind and harmonise it with what's already in bindAddr from
|
||||
// our own config's BindAddr. This prevents duplicate bindings etc.
|
||||
func parseSSBindAddr(ssRemoteHost string, ssRemotePort string, ckBindAddr *[]net.Addr) error {
|
||||
var ssBind string
|
||||
// When listening on an IPv6 and IPv4, SS gives REMOTE_HOST as e.g. ::|0.0.0.0
|
||||
v4nv6 := len(strings.Split(ssRemoteHost, "|")) == 2
|
||||
if v4nv6 {
|
||||
ssBind = ":" + ssRemotePort
|
||||
} else {
|
||||
ssBind = net.JoinHostPort(ssRemoteHost, ssRemotePort)
|
||||
}
|
||||
ssBindAddr, err := net.ResolveTCPAddr("tcp", ssBind)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to resolve bind address provided by SS: %v", err)
|
||||
func dispatchConnection(conn net.Conn, sta *server.State) {
|
||||
goWeb := func(data []byte) {
|
||||
webConn, err := net.Dial("tcp", sta.WebServerAddr)
|
||||
if err != nil {
|
||||
log.Printf("Making connection to redirection server: %v\n", err)
|
||||
return
|
||||
}
|
||||
webConn.Write(data)
|
||||
go pipe(webConn, conn)
|
||||
go pipe(conn, webConn)
|
||||
}
|
||||
|
||||
shouldAppend := true
|
||||
for i, addr := range *ckBindAddr {
|
||||
if addr.String() == ssBindAddr.String() {
|
||||
shouldAppend = false
|
||||
buf := make([]byte, 1500)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
i, err := io.ReadAtLeast(conn, buf, 1)
|
||||
if err != nil {
|
||||
go conn.Close()
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
data := buf[:i]
|
||||
ch, err := server.ParseClientHello(data)
|
||||
if err != nil {
|
||||
log.Printf("+1 non SS non (or malformed) TLS traffic from %v\n", conn.RemoteAddr())
|
||||
goWeb(data)
|
||||
return
|
||||
}
|
||||
|
||||
isSS, UID, sessionID := server.TouchStone(ch, sta)
|
||||
if !isSS {
|
||||
log.Printf("+1 non SS TLS traffic from %v\n", conn.RemoteAddr())
|
||||
goWeb(data)
|
||||
return
|
||||
}
|
||||
|
||||
finishHandshake := func() error {
|
||||
reply := server.ComposeReply(ch)
|
||||
_, err = conn.Write(reply)
|
||||
if err != nil {
|
||||
go conn.Close()
|
||||
return err
|
||||
}
|
||||
if addr.String() == ":"+ssRemotePort { // already listening on all interfaces
|
||||
shouldAppend = false
|
||||
}
|
||||
if addr.String() == "0.0.0.0:"+ssRemotePort || addr.String() == "[::]:"+ssRemotePort {
|
||||
// if config listens on one ip version but ss wants to listen on both,
|
||||
// listen on both
|
||||
if ssBindAddr.String() == ":"+ssRemotePort {
|
||||
shouldAppend = true
|
||||
(*ckBindAddr)[i] = ssBindAddr
|
||||
|
||||
// Two discarded messages: ChangeCipherSpec and Finished
|
||||
discardBuf := make([]byte, 1024)
|
||||
for c := 0; c < 2; c++ {
|
||||
_, err = util.ReadTLS(conn, discardBuf)
|
||||
if err != nil {
|
||||
go conn.Close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if shouldAppend {
|
||||
*ckBindAddr = append(*ckBindAddr, ssBindAddr)
|
||||
|
||||
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
|
||||
// added to the userinfo database. The distinction between going into the admin mode
|
||||
// and normal proxy mode is that sessionID needs == 0 for admin mode
|
||||
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 {
|
||||
err = finishHandshake()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
c := sta.Userpanel.MakeController(sta.AdminUID)
|
||||
for {
|
||||
n, err := conn.Read(data)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
resp, err := c.HandleRequest(data[:n])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
_, err = conn.Write(resp)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
|
||||
var user *usermanager.User
|
||||
if bytes.Equal(UID, sta.AdminUID) {
|
||||
user, err = sta.Userpanel.GetAndActivateAdminUser(UID)
|
||||
} else {
|
||||
user, err = sta.Userpanel.GetAndActivateUser(UID)
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID)
|
||||
goWeb(data)
|
||||
return
|
||||
}
|
||||
|
||||
err = finishHandshake()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
sesh, existing, err := user.GetSession(sessionID, mux.MakeObfs(UID), mux.MakeDeobfs(UID), util.ReadTLS)
|
||||
if err != nil {
|
||||
user.DelSession(sessionID)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
if existing {
|
||||
sesh.AddConnection(conn)
|
||||
return
|
||||
} else {
|
||||
log.Printf("New session from UID:%v, sessionID:%v\n", base64.StdEncoding.EncodeToString(UID), sessionID)
|
||||
sesh.AddConnection(conn)
|
||||
for {
|
||||
newStream, err := sesh.AcceptStream()
|
||||
if err != nil {
|
||||
if err == mux.ErrBrokenSession {
|
||||
log.Printf("Session closed for UID:%v, sessionID:%v\n", base64.StdEncoding.EncodeToString(UID), sessionID)
|
||||
user.DelSession(sessionID)
|
||||
return
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
ssIP := sta.SS_LOCAL_HOST
|
||||
if net.ParseIP(ssIP).To4() == nil {
|
||||
// IPv6 needs square brackets
|
||||
ssIP = "[" + ssIP + "]"
|
||||
}
|
||||
ssConn, err := net.Dial("tcp", ssIP+":"+sta.SS_LOCAL_PORT)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to ssserver: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go pipe(ssConn, newStream)
|
||||
go pipe(newStream, ssConn)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
var config string
|
||||
// Should be 127.0.0.1 to listen to ss-server on this machine
|
||||
var localHost string
|
||||
// server_port in ss config, same as remotePort in plugin mode
|
||||
var localPort string
|
||||
// server in ss config, the outbound listening ip
|
||||
var remoteHost string
|
||||
// Outbound listening ip, should be 443
|
||||
var remotePort string
|
||||
var pluginOpts string
|
||||
|
||||
var pluginMode bool
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
})
|
||||
|
||||
if os.Getenv("SS_LOCAL_HOST") != "" && os.Getenv("SS_LOCAL_PORT") != "" {
|
||||
pluginMode = true
|
||||
config = os.Getenv("SS_PLUGIN_OPTIONS")
|
||||
if os.Getenv("SS_LOCAL_HOST") != "" {
|
||||
localHost = os.Getenv("SS_LOCAL_HOST")
|
||||
localPort = os.Getenv("SS_LOCAL_PORT")
|
||||
remoteHost = os.Getenv("SS_REMOTE_HOST")
|
||||
remotePort = os.Getenv("SS_REMOTE_PORT")
|
||||
pluginOpts = os.Getenv("SS_PLUGIN_OPTIONS")
|
||||
} else {
|
||||
flag.StringVar(&config, "c", "server.json", "config: path to the configuration file or its content")
|
||||
localAddr := flag.String("r", "", "localAddr: the ip:port ss-server is listening on, set in Shadowsocks' configuration. If ss-server is running locally, it should be 127.0.0.1:some port")
|
||||
flag.StringVar(&remoteHost, "s", "0.0.0.0", "remoteHost: outbound listing ip, set to 0.0.0.0 to listen to everything")
|
||||
flag.StringVar(&remotePort, "p", "443", "remotePort: outbound listing port, should be 443")
|
||||
flag.StringVar(&pluginOpts, "c", "server.json", "pluginOpts: path to server.json or options seperated by semicolons")
|
||||
askVersion := flag.Bool("v", false, "Print the version number")
|
||||
printUsage := flag.Bool("h", false, "Print this message")
|
||||
|
||||
genUIDScript := flag.Bool("u", false, "Generate a UID to STDOUT")
|
||||
genKeyPairScript := flag.Bool("k", false, "Generate a pair of public and private key and output to STDOUT in the format of <public key>,<private key>")
|
||||
|
||||
genUIDHuman := flag.Bool("uid", false, "Generate and print out a UID")
|
||||
genKeyPairHuman := flag.Bool("key", false, "Generate and print out a public-private key pair")
|
||||
genUID := flag.Bool("u", false, "Generate a UID")
|
||||
genKeyPair := flag.Bool("k", false, "Generate a pair of public and private key, output in the format of pubkey,pvkey")
|
||||
|
||||
pprofAddr := flag.String("d", "", "debug use: ip:port to be listened by pprof profiler")
|
||||
verbosity := flag.String("verbosity", "info", "verbosity level")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *askVersion {
|
||||
fmt.Printf("ck-server %s", version)
|
||||
fmt.Printf("ck-server %s\n", version)
|
||||
return
|
||||
}
|
||||
if *printUsage {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
if *genUIDScript || *genUIDHuman {
|
||||
uid := generateUID()
|
||||
if *genUIDScript {
|
||||
fmt.Println(uid)
|
||||
} else {
|
||||
fmt.Printf("\x1B[35mYour UID is:\u001B[0m %s\n", uid)
|
||||
}
|
||||
if *genUID {
|
||||
fmt.Println(generateUID())
|
||||
return
|
||||
}
|
||||
if *genKeyPairScript || *genKeyPairHuman {
|
||||
if *genKeyPair {
|
||||
pub, pv := generateKeyPair()
|
||||
if *genKeyPairScript {
|
||||
fmt.Printf("%v,%v\n", pub, pv)
|
||||
} else {
|
||||
fmt.Printf("\x1B[36mYour PUBLIC key is:\x1B[0m %65s\n", pub)
|
||||
fmt.Printf("\x1B[33mYour PRIVATE key is (keep it secret):\x1B[0m %47s\n", pv)
|
||||
}
|
||||
fmt.Printf("%v,%v", pub, pv)
|
||||
return
|
||||
}
|
||||
|
||||
if *pprofAddr != "" {
|
||||
runtime.SetBlockProfileRate(5)
|
||||
go func() {
|
||||
log.Info(http.ListenAndServe(*pprofAddr, nil))
|
||||
}()
|
||||
log.Infof("pprof listening on %v", *pprofAddr)
|
||||
|
||||
startPprof(*pprofAddr)
|
||||
}
|
||||
|
||||
lvl, err := log.ParseLevel(*verbosity)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
if *localAddr == "" {
|
||||
log.Fatal("Must specify localAddr")
|
||||
}
|
||||
log.SetLevel(lvl)
|
||||
|
||||
log.Infof("Starting standalone mode")
|
||||
localHost = strings.Split(*localAddr, ":")[0]
|
||||
localPort = strings.Split(*localAddr, ":")[1]
|
||||
log.Printf("Starting standalone mode, listening on %v:%v to ss at %v:%v\n", remoteHost, remotePort, localHost, localPort)
|
||||
}
|
||||
sta, _ := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now)
|
||||
|
||||
raw, err := server.ParseConfig(config)
|
||||
err := sta.ParseConfig(pluginOpts)
|
||||
if err != nil {
|
||||
log.Fatalf("Configuration file error: %v", err)
|
||||
}
|
||||
|
||||
bindAddr, err := resolveBindAddr(raw.BindAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("unable to parse BindAddr: %v", err)
|
||||
if sta.AdminUID == nil {
|
||||
log.Fatalln("AdminUID cannot be empty!")
|
||||
}
|
||||
|
||||
// in case the user hasn't specified any local address to bind to, we listen on 443 and 80
|
||||
if !pluginMode && len(bindAddr) == 0 {
|
||||
https, _ := net.ResolveTCPAddr("tcp", ":443")
|
||||
http, _ := net.ResolveTCPAddr("tcp", ":80")
|
||||
bindAddr = []net.Addr{https, http}
|
||||
}
|
||||
go sta.UsedRandomCleaner()
|
||||
|
||||
// when cloak is started as a shadowsocks plugin, we parse the address ss-server
|
||||
// is listening on into ProxyBook, and we parse the list of bindAddr
|
||||
if pluginMode {
|
||||
ssLocalHost := os.Getenv("SS_LOCAL_HOST")
|
||||
ssLocalPort := os.Getenv("SS_LOCAL_PORT")
|
||||
raw.ProxyBook["shadowsocks"] = []string{"tcp", net.JoinHostPort(ssLocalHost, ssLocalPort)}
|
||||
|
||||
ssRemoteHost := os.Getenv("SS_REMOTE_HOST")
|
||||
ssRemotePort := os.Getenv("SS_REMOTE_PORT")
|
||||
err = parseSSBindAddr(ssRemoteHost, ssRemotePort, &bindAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to parse SS_REMOTE_HOST and SS_REMOTE_PORT: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sta, err := server.InitState(raw, common.RealWorldState)
|
||||
if err != nil {
|
||||
log.Fatalf("unable to initialise server state: %v", err)
|
||||
}
|
||||
|
||||
listen := func(bindAddr net.Addr) {
|
||||
listener, err := net.Listen("tcp", bindAddr.String())
|
||||
log.Infof("Listening on %v", bindAddr)
|
||||
listen := func(addr, port string) {
|
||||
listener, err := net.Listen("tcp", addr+":"+port)
|
||||
log.Println("Listening on " + addr + ":" + port)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
server.Serve(listener, sta)
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("%v", err)
|
||||
continue
|
||||
}
|
||||
go dispatchConnection(conn, sta)
|
||||
}
|
||||
}
|
||||
|
||||
for i, addr := range bindAddr {
|
||||
if i != len(bindAddr)-1 {
|
||||
go listen(addr)
|
||||
// When listening on an IPv6 and IPv4, SS gives REMOTE_HOST as e.g. ::|0.0.0.0
|
||||
listeningIP := strings.Split(sta.SS_REMOTE_HOST, "|")
|
||||
for i, ip := range listeningIP {
|
||||
if net.ParseIP(ip).To4() == nil {
|
||||
// IPv6 needs square brackets
|
||||
ip = "[" + ip + "]"
|
||||
}
|
||||
|
||||
// The last listener must block main() because the program exits on main return.
|
||||
if i == len(listeningIP)-1 {
|
||||
listen(ip, sta.SS_REMOTE_PORT)
|
||||
} else {
|
||||
// we block the main goroutine here so it doesn't quit
|
||||
listen(addr)
|
||||
go listen(ip, sta.SS_REMOTE_PORT)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,136 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseBindAddr(t *testing.T) {
|
||||
t.Run("port only", func(t *testing.T) {
|
||||
addrs, err := resolveBindAddr([]string{":443"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ":443", addrs[0].String())
|
||||
})
|
||||
|
||||
t.Run("specific address", func(t *testing.T) {
|
||||
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "192.168.1.123:443", addrs[0].String())
|
||||
})
|
||||
|
||||
t.Run("ipv6", func(t *testing.T) {
|
||||
addrs, err := resolveBindAddr([]string{"[::]:443"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "[::]:443", addrs[0].String())
|
||||
})
|
||||
|
||||
t.Run("mixed", func(t *testing.T) {
|
||||
addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ":80", addrs[0].String())
|
||||
assert.Equal(t, "[::]:443", addrs[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func assertSetEqual(t *testing.T, list1, list2 interface{}, msgAndArgs ...interface{}) (ok bool) {
|
||||
return assert.Subset(t, list1, list2, msgAndArgs) && assert.Subset(t, list2, list1, msgAndArgs)
|
||||
}
|
||||
|
||||
func TestParseSSBindAddr(t *testing.T) {
|
||||
testTable := []struct {
|
||||
name string
|
||||
ssRemoteHost string
|
||||
ssRemotePort string
|
||||
ckBindAddr []net.Addr
|
||||
expectedAddr []net.Addr
|
||||
}{
|
||||
{
|
||||
"ss only ipv4",
|
||||
"127.0.0.1",
|
||||
"443",
|
||||
[]net.Addr{},
|
||||
[]net.Addr{
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"ss only ipv6",
|
||||
"::",
|
||||
"443",
|
||||
[]net.Addr{},
|
||||
[]net.Addr{
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("::"),
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
},
|
||||
//{
|
||||
// "ss only ipv4 and v6",
|
||||
// "::|127.0.0.1",
|
||||
// "443",
|
||||
// []net.Addr{},
|
||||
// []net.Addr{
|
||||
// &net.TCPAddr{
|
||||
// IP: net.ParseIP("::"),
|
||||
// Port: 443,
|
||||
// },
|
||||
// &net.TCPAddr{
|
||||
// IP: net.ParseIP("127.0.0.1"),
|
||||
// Port: 443,
|
||||
// },
|
||||
// },
|
||||
//},
|
||||
{
|
||||
"ss and existing agrees",
|
||||
"::",
|
||||
"443",
|
||||
[]net.Addr{
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("::"),
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
[]net.Addr{
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("::"),
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"ss adds onto existing",
|
||||
"127.0.0.1",
|
||||
"80",
|
||||
[]net.Addr{
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("::"),
|
||||
Port: 443,
|
||||
},
|
||||
},
|
||||
[]net.Addr{
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("::"),
|
||||
Port: 443,
|
||||
},
|
||||
&net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testTable {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.NoError(t, parseSSBindAddr(test.ssRemoteHost, test.ssRemotePort, &test.ckBindAddr))
|
||||
assertSetEqual(t, test.ckBindAddr, test.expectedAddr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -3,20 +3,20 @@ package main
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
)
|
||||
|
||||
var b64 = base64.StdEncoding.EncodeToString
|
||||
|
||||
func generateUID() string {
|
||||
UID := make([]byte, 16)
|
||||
common.CryptoRandRead(UID)
|
||||
return base64.StdEncoding.EncodeToString(UID)
|
||||
UID := make([]byte, 32)
|
||||
rand.Read(UID)
|
||||
return b64(UID)
|
||||
}
|
||||
|
||||
func generateKeyPair() (string, string) {
|
||||
staticPv, staticPub, _ := ecdh.GenerateKey(rand.Reader)
|
||||
marshPub := ecdh.Marshal(staticPub)
|
||||
marshPv := staticPv.(*[32]byte)[:]
|
||||
return base64.StdEncoding.EncodeToString(marshPub), base64.StdEncoding.EncodeToString(marshPv)
|
||||
return b64(marshPub), b64(marshPv)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
// +build !pprof
|
||||
|
||||
package main
|
||||
|
||||
import "log"
|
||||
|
||||
func startPprof(x string) {
|
||||
log.Println("pprof not available in release builds to reduce binary size")
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
// +build pprof
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func startPprof(pprofAddr string) {
|
||||
runtime.SetBlockProfileRate(5)
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe(pprofAddr, nil))
|
||||
}()
|
||||
log.Println("pprof listening on " + pprofAddr)
|
||||
}
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
coverage:
|
||||
status:
|
||||
project: off
|
||||
patch: off
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=",
|
||||
"PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=",
|
||||
"ServerName":"www.bing.com",
|
||||
"TicketTimeHint":3600,
|
||||
"NumConn":4,
|
||||
"MaskBrowser":"chrome"
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"WebServerAddr":"204.79.197.200:443",
|
||||
"PrivateKey":"EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=",
|
||||
"AdminUID":"ugDmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ=",
|
||||
"DatabasePath":"userinfo.db",
|
||||
"BackupDirPath":""
|
||||
}
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
{
|
||||
"Transport": "direct",
|
||||
"ProxyMethod": "shadowsocks",
|
||||
"EncryptionMethod": "plain",
|
||||
"UID": "---Your UID here---",
|
||||
"PublicKey": "---Public key here---",
|
||||
"ServerName": "www.bing.com",
|
||||
"NumConn": 4,
|
||||
"BrowserSig": "chrome",
|
||||
"StreamTimeout": 300
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
{
|
||||
"ProxyBook": {
|
||||
"shadowsocks": [
|
||||
"tcp",
|
||||
"127.0.0.1:8388"
|
||||
],
|
||||
"openvpn": [
|
||||
"udp",
|
||||
"127.0.0.1:8389"
|
||||
],
|
||||
"tor": [
|
||||
"tcp",
|
||||
"127.0.0.1:9001"
|
||||
]
|
||||
},
|
||||
"BindAddr": [
|
||||
":443",
|
||||
":80"
|
||||
],
|
||||
"BypassUID": [
|
||||
"---Bypass UID here---"
|
||||
],
|
||||
"RedirAddr": "cloudflare.com",
|
||||
"PrivateKey": "---Private key here---",
|
||||
"AdminUID": "---Admin UID here (optional)---",
|
||||
"DatabasePath": "userinfo.db"
|
||||
}
|
||||
32
go.mod
32
go.mod
|
|
@ -1,30 +1,10 @@
|
|||
module github.com/cbeuw/Cloak
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/juju/ratelimit v1.0.2
|
||||
github.com/refraction-networking/utls v1.8.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.etcd.io/bbolt v1.4.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
golang.org/x/sys v0.32.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
github.com/boltdb/bolt v1.3.1
|
||||
github.com/juju/ratelimit v1.0.1
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b
|
||||
golang.org/x/sys v0.0.0-20190124100055-b90733256f2e // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
)
|
||||
|
|
|
|||
72
go.sum
72
go.sum
|
|
@ -1,61 +1,15 @@
|
|||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3 h1:LRxW8pdmWmyhoNh+TxUjxsAinGtCsVGjsl3xg6zoRSs=
|
||||
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3/go.mod h1:6jR2SzckGv8hIIS9zWJ160mzGVVOYp4AXZMDtacL6LE=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI=
|
||||
github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4=
|
||||
github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps=
|
||||
github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY=
|
||||
github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.6.6 h1:igFsYBUJPYM8Rno9xUuDoM5GQrVEqY4llzEXOkL43Ig=
|
||||
github.com/refraction-networking/utls v1.6.6/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0=
|
||||
github.com/refraction-networking/utls v1.7.0/go.mod h1:lV0Gwc1/Fi+HYH8hOtgFRdHfKo4FKSn6+FdyOz9hRms=
|
||||
github.com/refraction-networking/utls v1.7.3 h1:L0WRhHY7Oq1T0zkdzVZMR6zWZv+sXbHB9zcuvsAEqCo=
|
||||
github.com/refraction-networking/utls v1.7.3/go.mod h1:TUhh27RHMGtQvjQq+RyO11P6ZNQNBb3N0v7wsEjKAIQ=
|
||||
github.com/refraction-networking/utls v1.8.0 h1:L38krhiTAyj9EeiQQa2sg+hYb4qwLCqdMcpZrRfbONE=
|
||||
github.com/refraction-networking/utls v1.8.0/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
go.etcd.io/bbolt v1.4.0 h1:TU77id3TnN/zKr7CO/uk+fBCwF2jGcMuw2B/FMAzYIk=
|
||||
go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b h1:Elez2XeF2p9uyVj0yEUDqQ56NFcDtcBNkYP7yv8YbUE=
|
||||
golang.org/x/crypto v0.0.0-20190123085648-057139ce5d2b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/sys v0.0.0-20190124100055-b90733256f2e h1:3GIlrlVLfkoipSReOMNAgApI0ajnalyLa/EZHHca/XI=
|
||||
golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
|
|
|||
|
|
@ -1,168 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
utls "github.com/refraction-networking/utls"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const appDataMaxLength = 16401
|
||||
|
||||
type clientHelloFields struct {
|
||||
random []byte
|
||||
sessionId []byte
|
||||
x25519KeyShare []byte
|
||||
serverName string
|
||||
}
|
||||
|
||||
type browser int
|
||||
|
||||
const (
|
||||
chrome = iota
|
||||
firefox
|
||||
safari
|
||||
)
|
||||
|
||||
type DirectTLS struct {
|
||||
*common.TLSConn
|
||||
browser browser
|
||||
}
|
||||
|
||||
var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"}
|
||||
|
||||
func randomServerName() string {
|
||||
/*
|
||||
Copyright: Proton AG
|
||||
https://github.com/ProtonVPN/wireguard-go/commit/bcf344b39b213c1f32147851af0d2a8da9266883
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
of the Software, and to permit persons to whom the Software is furnished to do
|
||||
so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
*/
|
||||
charNum := int('z') - int('a') + 1
|
||||
size := 3 + common.RandInt(10)
|
||||
name := make([]byte, size)
|
||||
for i := range name {
|
||||
name[i] = byte(int('a') + common.RandInt(charNum))
|
||||
}
|
||||
return string(name) + "." + common.RandItem(topLevelDomains)
|
||||
}
|
||||
|
||||
func buildClientHello(browser browser, fields clientHelloFields) ([]byte, error) {
|
||||
// We don't use utls to handle connections (as it'll attempt a real TLS negotiation)
|
||||
// We only want it to build the ClientHello locally
|
||||
fakeConn := net.TCPConn{}
|
||||
var helloID utls.ClientHelloID
|
||||
switch browser {
|
||||
case chrome:
|
||||
helloID = utls.HelloChrome_Auto
|
||||
case firefox:
|
||||
helloID = utls.HelloFirefox_Auto
|
||||
case safari:
|
||||
helloID = utls.HelloSafari_Auto
|
||||
}
|
||||
|
||||
uclient := utls.UClient(&fakeConn, &utls.Config{ServerName: fields.serverName}, helloID)
|
||||
if err := uclient.BuildHandshakeState(); err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
if err := uclient.SetClientRandom(fields.random); err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
uclient.HandshakeState.Hello.SessionId = make([]byte, 32)
|
||||
copy(uclient.HandshakeState.Hello.SessionId, fields.sessionId)
|
||||
|
||||
// Find the X25519 key share and overwrite it
|
||||
var extIndex int
|
||||
var keyShareIndex int
|
||||
for i, ext := range uclient.Extensions {
|
||||
ext, ok := ext.(*utls.KeyShareExtension)
|
||||
if ok {
|
||||
extIndex = i
|
||||
for j, keyShare := range ext.KeyShares {
|
||||
if keyShare.Group == utls.X25519 {
|
||||
keyShareIndex = j
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
copy(uclient.Extensions[extIndex].(*utls.KeyShareExtension).KeyShares[keyShareIndex].Data, fields.x25519KeyShare)
|
||||
|
||||
if err := uclient.BuildHandshakeState(); err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
return uclient.HandshakeState.Hello.Raw, nil
|
||||
}
|
||||
|
||||
// Handshake handles the TLS handshake for a given conn and returns the sessionKey
|
||||
// if the server proceed with Cloak authentication
|
||||
func (tls *DirectTLS) Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) {
|
||||
payload, sharedSecret := makeAuthenticationPayload(authInfo)
|
||||
|
||||
fields := clientHelloFields{
|
||||
random: payload.randPubKey[:],
|
||||
sessionId: payload.ciphertextWithTag[0:32],
|
||||
x25519KeyShare: payload.ciphertextWithTag[32:64],
|
||||
serverName: authInfo.MockDomain,
|
||||
}
|
||||
|
||||
if strings.EqualFold(fields.serverName, "random") {
|
||||
fields.serverName = randomServerName()
|
||||
}
|
||||
|
||||
var ch []byte
|
||||
ch, err = buildClientHello(tls.browser, fields)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
chWithRecordLayer := common.AddRecordLayer(ch, common.Handshake, common.VersionTLS11)
|
||||
_, err = rawConn.Write(chWithRecordLayer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Trace("client hello sent successfully")
|
||||
tls.TLSConn = common.NewTLSConn(rawConn)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
log.Trace("waiting for ServerHello")
|
||||
_, err = tls.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
encrypted := append(buf[6:38], buf[84:116]...)
|
||||
nonce := encrypted[0:12]
|
||||
ciphertextWithTag := encrypted[12:60]
|
||||
sessionKeySlice, err := common.AESGCMDecrypt(nonce, sharedSecret[:], ciphertextWithTag)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
copy(sessionKey[:], sessionKeySlice)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
// ChangeCipherSpec and EncryptedCert (in the format of application data)
|
||||
_, err = tls.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return sessionKey, nil
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
package TLS
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/cbeuw/Cloak/internal/client"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
"time"
|
||||
)
|
||||
|
||||
type browser interface {
|
||||
composeExtensions()
|
||||
composeClientHello()
|
||||
}
|
||||
|
||||
func makeServerName(sta *client.State) []byte {
|
||||
serverName := sta.ServerName
|
||||
serverNameListLength := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(serverNameListLength, uint16(len(serverName)+3))
|
||||
serverNameType := []byte{0x00} // host_name
|
||||
serverNameLength := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(serverNameLength, uint16(len(serverName)))
|
||||
ret := make([]byte, 2+1+2+len(serverName))
|
||||
copy(ret[0:2], serverNameListLength)
|
||||
copy(ret[2:3], serverNameType)
|
||||
copy(ret[3:5], serverNameLength)
|
||||
copy(ret[5:], serverName)
|
||||
return ret
|
||||
}
|
||||
|
||||
func makeNullBytes(length int) []byte {
|
||||
ret := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
ret[i] = 0x00
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// addExtensionRecord, add type, length to extension data
|
||||
func addExtRec(typ []byte, data []byte) []byte {
|
||||
length := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(length, uint16(len(data)))
|
||||
ret := make([]byte, 2+2+len(data))
|
||||
copy(ret[0:2], typ)
|
||||
copy(ret[2:4], length)
|
||||
copy(ret[4:], data)
|
||||
return ret
|
||||
}
|
||||
|
||||
// ComposeInitHandshake composes ClientHello with record layer
|
||||
func ComposeInitHandshake(sta *client.State) []byte {
|
||||
var ch []byte
|
||||
switch sta.MaskBrowser {
|
||||
case "chrome":
|
||||
ch = (&chrome{}).composeClientHello(sta)
|
||||
case "firefox":
|
||||
ch = (&firefox{}).composeClientHello(sta)
|
||||
default:
|
||||
panic("Unsupported browser:" + sta.MaskBrowser)
|
||||
}
|
||||
return util.AddRecordLayer(ch, []byte{0x16}, []byte{0x03, 0x01})
|
||||
}
|
||||
|
||||
// ComposeReply composes RL+ChangeCipherSpec+RL+Finished
|
||||
func ComposeReply() []byte {
|
||||
TLS12 := []byte{0x03, 0x03}
|
||||
ccsBytes := util.AddRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
|
||||
finished := util.PsudoRandBytes(40, time.Now().UnixNano())
|
||||
fBytes := util.AddRecordLayer(finished, []byte{0x16}, TLS12)
|
||||
return append(ccsBytes, fBytes...)
|
||||
}
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
// Chrome 64
|
||||
|
||||
package TLS
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/client"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
type chrome struct {
|
||||
browser
|
||||
}
|
||||
|
||||
func (c *chrome) composeExtensions(sta *client.State) []byte {
|
||||
// see https://tools.ietf.org/html/draft-davidben-tls-grease-01
|
||||
// This is exclusive to chrome.
|
||||
makeGREASE := func() []byte {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
sixteenth := rand.Intn(16)
|
||||
monoGREASE := byte(sixteenth*16 + 0xA)
|
||||
doubleGREASE := []byte{monoGREASE, monoGREASE}
|
||||
return doubleGREASE
|
||||
}
|
||||
|
||||
makeSupportedGroups := func() []byte {
|
||||
suppGroupListLen := []byte{0x00, 0x08}
|
||||
ret := make([]byte, 2+8)
|
||||
copy(ret[0:2], suppGroupListLen)
|
||||
copy(ret[2:4], makeGREASE())
|
||||
copy(ret[4:], []byte{0x00, 0x1d, 0x00, 0x17, 0x00, 0x18})
|
||||
return ret
|
||||
}
|
||||
|
||||
var ext [14][]byte
|
||||
ext[0] = addExtRec(makeGREASE(), nil) // First GREASE
|
||||
ext[1] = addExtRec([]byte{0xff, 0x01}, []byte{0x00}) // renegotiation_info
|
||||
ext[2] = addExtRec([]byte{0x00, 0x00}, makeServerName(sta)) // server name indication
|
||||
ext[3] = addExtRec([]byte{0x00, 0x17}, nil) // extended_master_secret
|
||||
ext[4] = addExtRec([]byte{0x00, 0x23}, client.MakeSessionTicket(sta)) // Session tickets
|
||||
sigAlgo, _ := hex.DecodeString("0012040308040401050308050501080606010201")
|
||||
ext[5] = addExtRec([]byte{0x00, 0x0d}, sigAlgo) // Signature Algorithms
|
||||
ext[6] = addExtRec([]byte{0x00, 0x05}, []byte{0x01, 0x00, 0x00, 0x00, 0x00}) // status request
|
||||
ext[7] = addExtRec([]byte{0x00, 0x12}, nil) // signed cert timestamp
|
||||
APLN, _ := hex.DecodeString("000c02683208687474702f312e31")
|
||||
ext[8] = addExtRec([]byte{0x00, 0x10}, APLN) // app layer proto negotiation
|
||||
ext[9] = addExtRec([]byte{0x75, 0x50}, nil) // channel id
|
||||
ext[10] = addExtRec([]byte{0x00, 0x0b}, []byte{0x01, 0x00}) // ec point formats
|
||||
ext[11] = addExtRec([]byte{0x00, 0x0a}, makeSupportedGroups()) // supported groups
|
||||
ext[12] = addExtRec(makeGREASE(), []byte{0x00}) // Last GREASE
|
||||
ext[13] = addExtRec([]byte{0x00, 0x15}, makeNullBytes(110-len(ext[2]))) // padding
|
||||
var ret []byte
|
||||
for i := 0; i < 14; i++ {
|
||||
ret = append(ret, ext[i]...)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *chrome) composeClientHello(sta *client.State) []byte {
|
||||
var clientHello [12][]byte
|
||||
clientHello[0] = []byte{0x01} // handshake type
|
||||
clientHello[1] = []byte{0x00, 0x01, 0xfc} // length 508
|
||||
clientHello[2] = []byte{0x03, 0x03} // client version
|
||||
clientHello[3] = client.MakeRandomField(sta) // random
|
||||
clientHello[4] = []byte{0x20} // session id length 32
|
||||
clientHello[5] = util.PsudoRandBytes(32, sta.Now().UnixNano()) // session id
|
||||
clientHello[6] = []byte{0x00, 0x1c} // cipher suites length 28
|
||||
cipherSuites, _ := hex.DecodeString("2a2ac02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a")
|
||||
clientHello[7] = cipherSuites // cipher suites
|
||||
clientHello[8] = []byte{0x01} // compression methods length 1
|
||||
clientHello[9] = []byte{0x00} // compression methods
|
||||
clientHello[10] = []byte{0x01, 0x97} // extensions length 407
|
||||
clientHello[11] = c.composeExtensions(sta) // extensions
|
||||
var ret []byte
|
||||
for i := 0; i < 12; i++ {
|
||||
ret = append(ret, clientHello[i]...)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
// Firefox 58
|
||||
package TLS
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/client"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
type firefox struct {
|
||||
browser
|
||||
}
|
||||
|
||||
func (f *firefox) composeExtensions(sta *client.State) []byte {
|
||||
var ext [10][]byte
|
||||
ext[0] = addExtRec([]byte{0x00, 0x00}, makeServerName(sta)) // server name indication
|
||||
ext[1] = addExtRec([]byte{0x00, 0x17}, nil) // extended_master_secret
|
||||
ext[2] = addExtRec([]byte{0xff, 0x01}, []byte{0x00}) // renegotiation_info
|
||||
suppGroup, _ := hex.DecodeString("0008001d001700180019")
|
||||
ext[3] = addExtRec([]byte{0x00, 0x0a}, suppGroup) // supported groups
|
||||
ext[4] = addExtRec([]byte{0x00, 0x0b}, []byte{0x01, 0x00}) // ec point formats
|
||||
ext[5] = addExtRec([]byte{0x00, 0x23}, client.MakeSessionTicket(sta)) // Session tickets
|
||||
APLN, _ := hex.DecodeString("000c02683208687474702f312e31")
|
||||
ext[6] = addExtRec([]byte{0x00, 0x10}, APLN) // app layer proto negotiation
|
||||
ext[7] = addExtRec([]byte{0x00, 0x05}, []byte{0x01, 0x00, 0x00, 0x00, 0x00}) // status request
|
||||
sigAlgo, _ := hex.DecodeString("001604030503060308040805080604010501060102030201")
|
||||
ext[8] = addExtRec([]byte{0x00, 0x0d}, sigAlgo) // Signature Algorithms
|
||||
ext[9] = addExtRec([]byte{0x00, 0x15}, makeNullBytes(121-len(ext[0]))) // padding
|
||||
var ret []byte
|
||||
for i := 0; i < 10; i++ {
|
||||
ret = append(ret, ext[i]...)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (f *firefox) composeClientHello(sta *client.State) []byte {
|
||||
var clientHello [12][]byte
|
||||
clientHello[0] = []byte{0x01} // handshake type
|
||||
clientHello[1] = []byte{0x00, 0x01, 0xfc} // length 508
|
||||
clientHello[2] = []byte{0x03, 0x03} // client version
|
||||
clientHello[3] = client.MakeRandomField(sta) // random
|
||||
clientHello[4] = []byte{0x20} // session id length 32
|
||||
clientHello[5] = util.PsudoRandBytes(32, sta.Now().UnixNano()) // session id
|
||||
clientHello[6] = []byte{0x00, 0x1e} // cipher suites length 28
|
||||
cipherSuites, _ := hex.DecodeString("c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a")
|
||||
clientHello[7] = cipherSuites // cipher suites
|
||||
clientHello[8] = []byte{0x01} // compression methods length 1
|
||||
clientHello[9] = []byte{0x00} // compression methods
|
||||
clientHello[10] = []byte{0x01, 0x95} // extensions length 405
|
||||
clientHello[11] = f.composeExtensions(sta) // extensions
|
||||
var ret []byte
|
||||
for i := 0; i < 12; i++ {
|
||||
ret = append(ret, clientHello[i]...)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
@ -1,56 +1,72 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
UNORDERED_FLAG = 0x01 // 0000 0001
|
||||
)
|
||||
|
||||
type authenticationPayload struct {
|
||||
randPubKey [32]byte
|
||||
ciphertextWithTag [64]byte
|
||||
type keyPair struct {
|
||||
crypto.PrivateKey
|
||||
crypto.PublicKey
|
||||
}
|
||||
|
||||
// makeAuthenticationPayload generates the ephemeral key pair, calculates the shared secret, and then compose and
|
||||
// encrypt the authenticationPayload
|
||||
func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sharedSecret [32]byte) {
|
||||
/*
|
||||
Authentication data:
|
||||
+----------+----------------+---------------------+-------------+--------------+--------+------------+
|
||||
| _UID_ | _Proxy Method_ | _Encryption Method_ | _Timestamp_ | _Session Id_ | _Flag_ | _reserved_ |
|
||||
+----------+----------------+---------------------+-------------+--------------+--------+------------+
|
||||
| 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes |
|
||||
+----------+----------------+---------------------+-------------+--------------+--------+------------+
|
||||
*/
|
||||
ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand)
|
||||
if err != nil {
|
||||
log.Panicf("failed to generate ephemeral key pair: %v", err)
|
||||
}
|
||||
copy(ret.randPubKey[:], ecdh.Marshal(ephPub))
|
||||
|
||||
plaintext := make([]byte, 48)
|
||||
copy(plaintext, authInfo.UID)
|
||||
copy(plaintext[16:28], authInfo.ProxyMethod)
|
||||
plaintext[28] = authInfo.EncryptionMethod
|
||||
binary.BigEndian.PutUint64(plaintext[29:37], uint64(authInfo.WorldState.Now().UTC().Unix()))
|
||||
binary.BigEndian.PutUint32(plaintext[37:41], authInfo.SessionId)
|
||||
|
||||
if authInfo.Unordered {
|
||||
plaintext[41] |= UNORDERED_FLAG
|
||||
}
|
||||
|
||||
secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)
|
||||
if err != nil {
|
||||
log.Panicf("error in generating shared secret: %v", err)
|
||||
}
|
||||
copy(sharedSecret[:], secret)
|
||||
ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext)
|
||||
copy(ret.ciphertextWithTag[:], ciphertextWithTag[:])
|
||||
return
|
||||
func MakeRandomField(sta *State) []byte {
|
||||
t := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(t, uint64(sta.Now().Unix()/(12*60*60)))
|
||||
rdm := make([]byte, 16)
|
||||
io.ReadFull(rand.Reader, rdm)
|
||||
preHash := make([]byte, 56)
|
||||
copy(preHash[0:32], sta.UID)
|
||||
copy(preHash[32:40], t)
|
||||
copy(preHash[40:56], rdm)
|
||||
h := sha256.New()
|
||||
h.Write(preHash)
|
||||
ret := make([]byte, 32)
|
||||
copy(ret[0:16], rdm)
|
||||
copy(ret[16:32], h.Sum(nil)[0:16])
|
||||
return ret
|
||||
}
|
||||
|
||||
func MakeSessionTicket(sta *State) []byte {
|
||||
// sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted UID+sessionID 36 bytes][padding 124 bytes]
|
||||
// The first 16 bytes of the marshalled ephemeral public key is used as the IV
|
||||
// for encrypting the UID
|
||||
tthInterval := sta.Now().Unix() / int64(sta.TicketTimeHint)
|
||||
sta.keyPairsM.Lock()
|
||||
ephKP := sta.keyPairs[tthInterval]
|
||||
if ephKP == nil {
|
||||
ephPv, ephPub, _ := ecdh.GenerateKey(rand.Reader)
|
||||
ephKP = &keyPair{
|
||||
ephPv,
|
||||
ephPub,
|
||||
}
|
||||
sta.keyPairs[tthInterval] = ephKP
|
||||
}
|
||||
sta.keyPairsM.Unlock()
|
||||
ticket := make([]byte, 192)
|
||||
copy(ticket[0:32], ecdh.Marshal(ephKP.PublicKey))
|
||||
key := ecdh.GenerateSharedSecret(ephKP.PrivateKey, sta.staticPub)
|
||||
plainUIDsID := make([]byte, 36)
|
||||
copy(plainUIDsID, sta.UID)
|
||||
binary.BigEndian.PutUint32(plainUIDsID[32:36], sta.sessionID)
|
||||
cipherUIDsID := util.AESEncrypt(ticket[0:16], key, plainUIDsID)
|
||||
copy(ticket[32:68], cipherUIDsID)
|
||||
// The purpose of adding sessionID is that, the generated padding of sessionTicket needs to be unpredictable.
|
||||
// As shown in auth.go, the padding is generated by a psudo random generator. The seed
|
||||
// needs to be the same for each TicketTimeHint interval. However the value of epoch/TicketTimeHint
|
||||
// is public knowledge, so is the psudo random algorithm used by math/rand. Therefore not only
|
||||
// can the firewall tell that the padding is generated in this specific way, this padding is identical
|
||||
// for all ckclients in the same TicketTimeHint interval. This will expose us.
|
||||
//
|
||||
// With the sessionID value generated at startup of ckclient and used as a part of the seed, the
|
||||
// sessionTicket is still identical for each TicketTimeHint interval, but others won't be able to know
|
||||
// how it was generated. It will also be different for each client.
|
||||
copy(ticket[68:192], util.PsudoRandBytes(124, tthInterval+int64(sta.sessionID)))
|
||||
return ticket
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,72 +2,101 @@ package client
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
prand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
)
|
||||
|
||||
func TestMakeAuthenticationPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
authInfo AuthInfo
|
||||
expPayload authenticationPayload
|
||||
expSecret [32]byte
|
||||
}{
|
||||
{
|
||||
AuthInfo{
|
||||
Unordered: false,
|
||||
SessionId: 3421516597,
|
||||
UID: []byte{
|
||||
0x4c, 0xd8, 0xcc, 0x15, 0x60, 0x0d, 0x7e,
|
||||
0xb6, 0x81, 0x31, 0xfd, 0x80, 0x97, 0x67, 0x37, 0x46},
|
||||
ServerPubKey: &[32]byte{
|
||||
0x21, 0x8a, 0x14, 0xce, 0x49, 0x5e, 0xfd, 0x3f,
|
||||
0xe4, 0xae, 0x21, 0x3e, 0x51, 0xf7, 0x66, 0xec,
|
||||
0x01, 0xd0, 0xb4, 0x87, 0x86, 0x9c, 0x15, 0x9b,
|
||||
0x86, 0x19, 0x53, 0x6e, 0x60, 0xe9, 0x51, 0x42},
|
||||
ProxyMethod: "shadowsocks",
|
||||
EncryptionMethod: multiplex.EncryptionMethodPlain,
|
||||
MockDomain: "d2jkinvisak5y9.cloudfront.net",
|
||||
WorldState: common.WorldState{
|
||||
Rand: bytes.NewBuffer([]byte{
|
||||
0xf1, 0x1e, 0x42, 0xe1, 0x84, 0x22, 0x07, 0xc5,
|
||||
0xc3, 0x5c, 0x0f, 0x7b, 0x01, 0xf3, 0x65, 0x2d,
|
||||
0xd7, 0x9b, 0xad, 0xb0, 0xb2, 0x77, 0xa2, 0x06,
|
||||
0x6b, 0x78, 0x1b, 0x74, 0x1f, 0x43, 0xc9, 0x80}),
|
||||
Now: func() time.Time { return time.Unix(1579908372, 0) },
|
||||
},
|
||||
},
|
||||
authenticationPayload{
|
||||
randPubKey: [32]byte{
|
||||
0xee, 0x9e, 0x41, 0x4e, 0xb3, 0x3b, 0x85, 0x03,
|
||||
0x6d, 0x85, 0xba, 0x30, 0x11, 0x31, 0x10, 0x24,
|
||||
0x4f, 0x7b, 0xd5, 0x38, 0x50, 0x0f, 0xf2, 0x4d,
|
||||
0xa3, 0xdf, 0xba, 0x76, 0x0a, 0xe9, 0x19, 0x19},
|
||||
ciphertextWithTag: [64]byte{
|
||||
0x71, 0xb1, 0x6c, 0x5a, 0x60, 0x46, 0x90, 0x12,
|
||||
0x36, 0x3b, 0x1b, 0xc4, 0x79, 0x3c, 0xab, 0xdd,
|
||||
0x5a, 0x53, 0xc5, 0xed, 0xaf, 0xdb, 0x10, 0x98,
|
||||
0x83, 0x96, 0x81, 0xa6, 0xfc, 0xa2, 0x1e, 0xb0,
|
||||
0x89, 0xb2, 0x29, 0x71, 0x7e, 0x45, 0x97, 0x54,
|
||||
0x11, 0x7d, 0x9b, 0x92, 0xbb, 0xd6, 0xce, 0x37,
|
||||
0x3b, 0xb8, 0x8b, 0xfb, 0xb6, 0x40, 0xf0, 0x2c,
|
||||
0x6c, 0x55, 0xb9, 0xfc, 0x5d, 0x34, 0x89, 0x41},
|
||||
},
|
||||
[32]byte{
|
||||
0xc7, 0xc6, 0x9b, 0xbe, 0xec, 0xf8, 0x35, 0x55,
|
||||
0x67, 0x20, 0xcd, 0xeb, 0x74, 0x16, 0xc5, 0x60,
|
||||
0xee, 0x9d, 0x63, 0x1a, 0x44, 0xc5, 0x09, 0xf6,
|
||||
0xe0, 0x24, 0xad, 0xd2, 0x10, 0xe3, 0x4a, 0x11},
|
||||
},
|
||||
func TestMakeSessionTicket(t *testing.T) {
|
||||
UID, _ := hex.DecodeString("26a8e88bcd7c64a69ca051740851d22a6818de2fddafc00882331f1c5a8b866c")
|
||||
staticPv, staticPub, _ := ecdh.GenerateKey(rand.Reader)
|
||||
mockSta := &State{
|
||||
Now: time.Now,
|
||||
sessionID: 42,
|
||||
UID: UID,
|
||||
staticPub: staticPub,
|
||||
keyPairs: make(map[int64]*keyPair),
|
||||
TicketTimeHint: 3600,
|
||||
}
|
||||
for _, tc := range tests {
|
||||
func() {
|
||||
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
|
||||
assert.Equal(t, tc.expPayload, payload, "payload doesn't match")
|
||||
assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match")
|
||||
}()
|
||||
|
||||
ticket := MakeSessionTicket(mockSta)
|
||||
|
||||
// verification
|
||||
ephPub, _ := ecdh.Unmarshal(ticket[0:32])
|
||||
key := ecdh.GenerateSharedSecret(staticPv, ephPub)
|
||||
|
||||
// aes decrypt
|
||||
UIDsID := make([]byte, len(ticket[32:68]))
|
||||
copy(UIDsID, ticket[32:68]) // Because XORKeyStream is inplace, but we don't want the input to be changed
|
||||
block, _ := aes.NewCipher(key)
|
||||
stream := cipher.NewCTR(block, ticket[0:16])
|
||||
stream.XORKeyStream(UIDsID, UIDsID)
|
||||
|
||||
decryUID := UIDsID[0:32]
|
||||
decrySessionID := binary.BigEndian.Uint32(UIDsID[32:36])
|
||||
|
||||
// check padding
|
||||
tthInterval := mockSta.Now().Unix() / int64(mockSta.TicketTimeHint)
|
||||
r := prand.New(prand.NewSource(tthInterval + int64(mockSta.sessionID)))
|
||||
pad := make([]byte, 124)
|
||||
r.Read(pad)
|
||||
|
||||
if !bytes.Equal(mockSta.UID, decryUID) {
|
||||
t.Error(
|
||||
"For", "UID",
|
||||
"expecting", fmt.Sprintf("%x", mockSta.UID),
|
||||
"got", fmt.Sprintf("%x", decryUID),
|
||||
)
|
||||
}
|
||||
if mockSta.sessionID != decrySessionID {
|
||||
t.Error(
|
||||
"For", "sessionID",
|
||||
"expecting", mockSta.sessionID,
|
||||
"got", decrySessionID,
|
||||
)
|
||||
}
|
||||
if !bytes.Equal(pad, ticket[68:]) {
|
||||
t.Error(
|
||||
"For", "Padding",
|
||||
"expecting", fmt.Sprintf("%x", pad),
|
||||
"got", fmt.Sprintf("%x", ticket[68:]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRandomField(t *testing.T) {
|
||||
UID, _ := hex.DecodeString("26a8e88bcd7c64a69ca051740851d22a6818de2fddafc00882331f1c5a8b866c")
|
||||
mockSta := &State{
|
||||
Now: time.Now,
|
||||
UID: UID,
|
||||
}
|
||||
random := MakeRandomField(mockSta)
|
||||
|
||||
// verification
|
||||
tb := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(tb, uint64(time.Now().Unix()/(12*60*60)))
|
||||
rdm := random[0:16]
|
||||
preHash := make([]byte, 56)
|
||||
copy(preHash[0:32], UID)
|
||||
copy(preHash[32:40], tb)
|
||||
copy(preHash[40:56], rdm)
|
||||
h := sha256.New()
|
||||
h.Write(preHash)
|
||||
exp := h.Sum(nil)[0:16]
|
||||
if !bytes.Equal(exp, random[16:32]) {
|
||||
t.Error(
|
||||
"For", "Random",
|
||||
"expecting", fmt.Sprintf("%x", exp),
|
||||
"got", fmt.Sprintf("%x", random[16:32]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// On different invocations to MakeSession, authInfo.SessionId MUST be different
|
||||
func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *mux.Session {
|
||||
log.Info("Attempting to start a new session")
|
||||
|
||||
connsCh := make(chan net.Conn, connConfig.NumConn)
|
||||
var _sessionKey atomic.Value
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < connConfig.NumConn; i++ {
|
||||
wg.Add(1)
|
||||
transportConfig := connConfig.Transport
|
||||
go func() {
|
||||
makeconn:
|
||||
transportConn := transportConfig.CreateTransport()
|
||||
remoteConn, err := dialer.Dial("tcp", connConfig.RemoteAddr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to establish new connections to remote: %v", err)
|
||||
// TODO increase the interval if failed multiple times
|
||||
time.Sleep(time.Second * 3)
|
||||
goto makeconn
|
||||
}
|
||||
|
||||
sk, err := transportConn.Handshake(remoteConn, authInfo)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to prepare connection to remote: %v", err)
|
||||
transportConn.Close()
|
||||
|
||||
// In Cloak v2.11.0, we've updated uTLS version and subsequently increased the first packet size for chrome above 1500
|
||||
// https://github.com/cbeuw/Cloak/pull/306#issuecomment-2862728738. As a backwards compatibility feature, if we fail
|
||||
// to connect using chrome signature, retry with firefox which has a smaller packet size.
|
||||
if transportConfig.mode == "direct" && transportConfig.browser == chrome {
|
||||
transportConfig.browser = firefox
|
||||
log.Warnf("failed to connect with chrome signature, falling back to retry with firefox")
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
goto makeconn
|
||||
}
|
||||
// sessionKey given by each connection should be identical
|
||||
_sessionKey.Store(sk)
|
||||
connsCh <- transportConn
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
log.Debug("All underlying connections established")
|
||||
|
||||
sessionKey := _sessionKey.Load().([32]byte)
|
||||
obfuscator, err := mux.MakeObfuscator(authInfo.EncryptionMethod, sessionKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
seshConfig := mux.SessionConfig{
|
||||
Singleplex: connConfig.Singleplex,
|
||||
Obfuscator: obfuscator,
|
||||
Valve: nil,
|
||||
Unordered: authInfo.Unordered,
|
||||
MsgOnWireSizeLimit: appDataMaxLength,
|
||||
}
|
||||
sesh := mux.MakeSession(authInfo.SessionId, seshConfig)
|
||||
|
||||
for i := 0; i < connConfig.NumConn; i++ {
|
||||
conn := <-connsCh
|
||||
sesh.AddConnection(conn)
|
||||
}
|
||||
|
||||
log.Infof("Session %v established", authInfo.SessionId)
|
||||
return sesh
|
||||
}
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) {
|
||||
var sesh *mux.Session
|
||||
localConn, err := bindFunc()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
streams := make(map[string]*mux.Stream)
|
||||
var streamsMutex sync.Mutex
|
||||
|
||||
data := make([]byte, 8192)
|
||||
for {
|
||||
i, addr, err := localConn.ReadFrom(data)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read first packet from proxy client: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !singleplex && (sesh == nil || sesh.IsClosed()) {
|
||||
sesh = newSeshFunc()
|
||||
}
|
||||
|
||||
streamsMutex.Lock()
|
||||
stream, ok := streams[addr.String()]
|
||||
if !ok {
|
||||
if singleplex {
|
||||
sesh = newSeshFunc()
|
||||
}
|
||||
|
||||
stream, err = sesh.OpenStream()
|
||||
if err != nil {
|
||||
if singleplex {
|
||||
sesh.Close()
|
||||
}
|
||||
log.Errorf("Failed to open stream: %v", err)
|
||||
streamsMutex.Unlock()
|
||||
continue
|
||||
}
|
||||
streams[addr.String()] = stream
|
||||
streamsMutex.Unlock()
|
||||
|
||||
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
|
||||
|
||||
proxyAddr := addr
|
||||
go func(stream *mux.Stream, localConn *net.UDPConn) {
|
||||
buf := make([]byte, 8192)
|
||||
for {
|
||||
n, err := stream.Read(buf)
|
||||
if err != nil {
|
||||
log.Tracef("copying stream to proxy client: %v", err)
|
||||
break
|
||||
}
|
||||
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
|
||||
|
||||
_, err = localConn.WriteTo(buf[:n], proxyAddr)
|
||||
if err != nil {
|
||||
log.Tracef("copying stream to proxy client: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
streamsMutex.Lock()
|
||||
delete(streams, addr.String())
|
||||
streamsMutex.Unlock()
|
||||
stream.Close()
|
||||
return
|
||||
}(stream, localConn)
|
||||
} else {
|
||||
streamsMutex.Unlock()
|
||||
}
|
||||
|
||||
_, err = stream.Write(data[:i])
|
||||
if err != nil {
|
||||
log.Tracef("copying proxy client to stream: %v", err)
|
||||
streamsMutex.Lock()
|
||||
delete(streams, addr.String())
|
||||
streamsMutex.Unlock()
|
||||
stream.Close()
|
||||
continue
|
||||
}
|
||||
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
|
||||
}
|
||||
}
|
||||
|
||||
func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) {
|
||||
var sesh *mux.Session
|
||||
for {
|
||||
localConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
continue
|
||||
}
|
||||
if !singleplex && (sesh == nil || sesh.IsClosed()) {
|
||||
sesh = newSeshFunc()
|
||||
}
|
||||
go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) {
|
||||
if singleplex {
|
||||
sesh = newSeshFunc()
|
||||
}
|
||||
|
||||
data := make([]byte, 10240)
|
||||
_ = localConn.SetReadDeadline(time.Now().Add(streamTimeout))
|
||||
i, err := io.ReadAtLeast(localConn, data, 1)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read first packet from proxy client: %v", err)
|
||||
localConn.Close()
|
||||
return
|
||||
}
|
||||
var zeroTime time.Time
|
||||
_ = localConn.SetReadDeadline(zeroTime)
|
||||
|
||||
stream, err := sesh.OpenStream()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open stream: %v", err)
|
||||
localConn.Close()
|
||||
if singleplex {
|
||||
sesh.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = stream.Write(data[:i])
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write to stream: %v", err)
|
||||
localConn.Close()
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if _, err := common.Copy(localConn, stream); err != nil {
|
||||
log.Tracef("copying stream to proxy client: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err = common.Copy(stream, localConn); err != nil {
|
||||
log.Tracef("copying proxy client to stream: %v", err)
|
||||
}
|
||||
}(sesh, localConn, streamTimeout)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,88 +2,68 @@ package client
|
|||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
)
|
||||
|
||||
// RawConfig represents the fields in the config json file
|
||||
// nullable means if it's empty, a default value will be chosen in ProcessRawConfig
|
||||
// jsonOptional means if the json's empty, its value will be set from environment variables or commandline args
|
||||
// but it mustn't be empty when ProcessRawConfig is called
|
||||
type RawConfig struct {
|
||||
ServerName string
|
||||
ProxyMethod string
|
||||
EncryptionMethod string
|
||||
UID []byte
|
||||
PublicKey []byte
|
||||
NumConn int
|
||||
LocalHost string // jsonOptional
|
||||
LocalPort string // jsonOptional
|
||||
RemoteHost string // jsonOptional
|
||||
RemotePort string // jsonOptional
|
||||
AlternativeNames []string // jsonOptional
|
||||
// defaults set in ProcessRawConfig
|
||||
UDP bool // nullable
|
||||
BrowserSig string // nullable
|
||||
Transport string // nullable
|
||||
CDNOriginHost string // nullable
|
||||
CDNWsUrlPath string // nullable
|
||||
StreamTimeout int // nullable
|
||||
KeepAlive int // nullable
|
||||
type rawConfig struct {
|
||||
ServerName string
|
||||
UID string
|
||||
PublicKey string
|
||||
TicketTimeHint int
|
||||
MaskBrowser string
|
||||
NumConn int
|
||||
}
|
||||
|
||||
type RemoteConnConfig struct {
|
||||
Singleplex bool
|
||||
NumConn int
|
||||
KeepAlive time.Duration
|
||||
RemoteAddr string
|
||||
Transport TransportConfig
|
||||
// State stores global variables
|
||||
type State struct {
|
||||
SS_LOCAL_HOST string
|
||||
SS_LOCAL_PORT string
|
||||
SS_REMOTE_HOST string
|
||||
SS_REMOTE_PORT string
|
||||
|
||||
Now func() time.Time
|
||||
sessionID uint32
|
||||
UID []byte
|
||||
staticPub crypto.PublicKey
|
||||
keyPairsM sync.RWMutex
|
||||
keyPairs map[int64]*keyPair
|
||||
|
||||
TicketTimeHint int
|
||||
ServerName string
|
||||
MaskBrowser string
|
||||
NumConn int
|
||||
}
|
||||
|
||||
type LocalConnConfig struct {
|
||||
LocalAddr string
|
||||
Timeout time.Duration
|
||||
MockDomainList []string
|
||||
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State {
|
||||
ret := &State{
|
||||
SS_LOCAL_HOST: localHost,
|
||||
SS_LOCAL_PORT: localPort,
|
||||
SS_REMOTE_HOST: remoteHost,
|
||||
SS_REMOTE_PORT: remotePort,
|
||||
Now: nowFunc,
|
||||
}
|
||||
ret.keyPairs = make(map[int64]*keyPair)
|
||||
return ret
|
||||
}
|
||||
|
||||
type AuthInfo struct {
|
||||
UID []byte
|
||||
SessionId uint32
|
||||
ProxyMethod string
|
||||
EncryptionMethod byte
|
||||
Unordered bool
|
||||
ServerPubKey crypto.PublicKey
|
||||
MockDomain string
|
||||
WorldState common.WorldState
|
||||
}
|
||||
func (sta *State) SetSessionID(id uint32) { sta.sessionID = id }
|
||||
|
||||
// semi-colon separated value. This is for Android plugin options
|
||||
func ssvToJson(ssv string) (ret []byte) {
|
||||
elem := func(val string, lst []string) bool {
|
||||
for _, v := range lst {
|
||||
if val == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
unescape := func(s string) string {
|
||||
r := strings.Replace(s, `\\`, `\`, -1)
|
||||
r = strings.Replace(r, `\=`, `=`, -1)
|
||||
r = strings.Replace(r, `\;`, `;`, -1)
|
||||
return r
|
||||
}
|
||||
unquoted := []string{"NumConn", "StreamTimeout", "KeepAlive", "UDP"}
|
||||
lines := strings.Split(unescape(ssv), ";")
|
||||
ret = []byte("{")
|
||||
for _, ln := range lines {
|
||||
|
|
@ -91,29 +71,11 @@ func ssvToJson(ssv string) (ret []byte) {
|
|||
break
|
||||
}
|
||||
sp := strings.SplitN(ln, "=", 2)
|
||||
if len(sp) < 2 {
|
||||
log.Errorf("Malformed config option: %v", ln)
|
||||
continue
|
||||
}
|
||||
key := sp[0]
|
||||
value := sp[1]
|
||||
if strings.HasPrefix(key, "AlternativeNames") {
|
||||
switch strings.Contains(value, ",") {
|
||||
case true:
|
||||
domains := strings.Split(value, ",")
|
||||
for index, domain := range domains {
|
||||
domains[index] = `"` + domain + `"`
|
||||
}
|
||||
value = strings.Join(domains, ",")
|
||||
ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...)
|
||||
case false:
|
||||
ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// JSON doesn't like quotation marks around int and bool
|
||||
// This is extremely ugly but it's still better than writing a tokeniser
|
||||
if elem(key, unquoted) {
|
||||
// JSON doesn't like quotation marks around int
|
||||
// Yes this is extremely ugly but it's still better than writing a tokeniser
|
||||
if key == "TicketTimeHint" || key == "NumConn" {
|
||||
ret = append(ret, []byte(`"`+key+`":`+value+`,`)...)
|
||||
} else {
|
||||
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)
|
||||
|
|
@ -124,156 +86,40 @@ func ssvToJson(ssv string) (ret []byte) {
|
|||
return ret
|
||||
}
|
||||
|
||||
func ParseConfig(conf string) (raw *RawConfig, err error) {
|
||||
// ParseConfig parses the config (either a path to json or Android config) into a State variable
|
||||
func (sta *State) ParseConfig(conf string) (err error) {
|
||||
var content []byte
|
||||
// Checking if it's a path to json or a ssv string
|
||||
if strings.Contains(conf, ";") && strings.Contains(conf, "=") {
|
||||
content = ssvToJson(conf)
|
||||
} else {
|
||||
content, err = ioutil.ReadFile(conf)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
raw = new(RawConfig)
|
||||
err = json.Unmarshal(content, &raw)
|
||||
var preParse rawConfig
|
||||
err = json.Unmarshal(content, &preParse)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
sta.ServerName = preParse.ServerName
|
||||
sta.TicketTimeHint = preParse.TicketTimeHint
|
||||
sta.MaskBrowser = preParse.MaskBrowser
|
||||
sta.NumConn = preParse.NumConn
|
||||
uid, err := base64.StdEncoding.DecodeString(preParse.UID)
|
||||
if err != nil {
|
||||
return errors.New("Failed to parse UID: " + err.Error())
|
||||
}
|
||||
sta.UID = uid
|
||||
|
||||
func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local LocalConnConfig, remote RemoteConnConfig, auth AuthInfo, err error) {
|
||||
nullErr := func(field string) (local LocalConnConfig, remote RemoteConnConfig, auth AuthInfo, err error) {
|
||||
err = fmt.Errorf("%v cannot be empty", field)
|
||||
return
|
||||
pubBytes, err := base64.StdEncoding.DecodeString(preParse.PublicKey)
|
||||
if err != nil {
|
||||
return errors.New("Failed to parse Public key: " + err.Error())
|
||||
}
|
||||
|
||||
auth.UID = raw.UID
|
||||
auth.Unordered = raw.UDP
|
||||
if raw.ServerName == "" {
|
||||
return nullErr("ServerName")
|
||||
}
|
||||
auth.MockDomain = raw.ServerName
|
||||
|
||||
var filteredAlternativeNames []string
|
||||
for _, alternativeName := range raw.AlternativeNames {
|
||||
if len(alternativeName) > 0 {
|
||||
filteredAlternativeNames = append(filteredAlternativeNames, alternativeName)
|
||||
}
|
||||
}
|
||||
raw.AlternativeNames = filteredAlternativeNames
|
||||
|
||||
local.MockDomainList = raw.AlternativeNames
|
||||
local.MockDomainList = append(local.MockDomainList, auth.MockDomain)
|
||||
if raw.ProxyMethod == "" {
|
||||
return nullErr("ServerName")
|
||||
}
|
||||
auth.ProxyMethod = raw.ProxyMethod
|
||||
if len(raw.UID) == 0 {
|
||||
return nullErr("UID")
|
||||
}
|
||||
|
||||
// static public key
|
||||
if len(raw.PublicKey) == 0 {
|
||||
return nullErr("PublicKey")
|
||||
}
|
||||
pub, ok := ecdh.Unmarshal(raw.PublicKey)
|
||||
pub, ok := ecdh.Unmarshal(pubBytes)
|
||||
if !ok {
|
||||
err = fmt.Errorf("failed to unmarshal Public key")
|
||||
return
|
||||
return errors.New("Failed to unmarshal Public key")
|
||||
}
|
||||
auth.ServerPubKey = pub
|
||||
auth.WorldState = worldState
|
||||
|
||||
// Encryption method
|
||||
switch strings.ToLower(raw.EncryptionMethod) {
|
||||
case "plain":
|
||||
auth.EncryptionMethod = mux.EncryptionMethodPlain
|
||||
case "aes-gcm", "aes-256-gcm":
|
||||
auth.EncryptionMethod = mux.EncryptionMethodAES256GCM
|
||||
case "aes-128-gcm":
|
||||
auth.EncryptionMethod = mux.EncryptionMethodAES128GCM
|
||||
case "chacha20-poly1305":
|
||||
auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305
|
||||
default:
|
||||
err = fmt.Errorf("unknown encryption method %v", raw.EncryptionMethod)
|
||||
return
|
||||
}
|
||||
|
||||
if raw.RemoteHost == "" {
|
||||
return nullErr("RemoteHost")
|
||||
}
|
||||
if raw.RemotePort == "" {
|
||||
return nullErr("RemotePort")
|
||||
}
|
||||
remote.RemoteAddr = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
|
||||
if raw.NumConn <= 0 {
|
||||
remote.NumConn = 1
|
||||
remote.Singleplex = true
|
||||
} else {
|
||||
remote.NumConn = raw.NumConn
|
||||
remote.Singleplex = false
|
||||
}
|
||||
|
||||
// Transport and (if TLS mode), browser
|
||||
switch strings.ToLower(raw.Transport) {
|
||||
case "cdn":
|
||||
var cdnDomainPort string
|
||||
if raw.CDNOriginHost == "" {
|
||||
cdnDomainPort = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
|
||||
} else {
|
||||
cdnDomainPort = net.JoinHostPort(raw.CDNOriginHost, raw.RemotePort)
|
||||
}
|
||||
if raw.CDNWsUrlPath == "" {
|
||||
raw.CDNWsUrlPath = "/"
|
||||
}
|
||||
|
||||
remote.Transport = TransportConfig{
|
||||
mode: "cdn",
|
||||
wsUrl: "ws://" + cdnDomainPort + raw.CDNWsUrlPath,
|
||||
}
|
||||
case "direct":
|
||||
fallthrough
|
||||
default:
|
||||
var browser browser
|
||||
switch strings.ToLower(raw.BrowserSig) {
|
||||
case "firefox":
|
||||
browser = firefox
|
||||
case "safari":
|
||||
browser = safari
|
||||
case "chrome":
|
||||
fallthrough
|
||||
default:
|
||||
browser = chrome
|
||||
}
|
||||
remote.Transport = TransportConfig{
|
||||
mode: "direct",
|
||||
browser: browser,
|
||||
}
|
||||
}
|
||||
|
||||
// KeepAlive
|
||||
if raw.KeepAlive <= 0 {
|
||||
remote.KeepAlive = -1
|
||||
} else {
|
||||
remote.KeepAlive = remote.KeepAlive * time.Second
|
||||
}
|
||||
|
||||
if raw.LocalHost == "" {
|
||||
return nullErr("LocalHost")
|
||||
}
|
||||
if raw.LocalPort == "" {
|
||||
return nullErr("LocalPort")
|
||||
}
|
||||
local.LocalAddr = net.JoinHostPort(raw.LocalHost, raw.LocalPort)
|
||||
// stream no write timeout
|
||||
if raw.StreamTimeout == 0 {
|
||||
local.Timeout = 300 * time.Second
|
||||
} else {
|
||||
local.Timeout = time.Duration(raw.StreamTimeout) * time.Second
|
||||
}
|
||||
|
||||
return
|
||||
sta.staticPub = pub
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,37 +1,20 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
ssv := "UID=iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=;PublicKey=IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=;" +
|
||||
"ServerName=www.bing.com;NumConn=4;MaskBrowser=chrome;ProxyMethod=shadowsocks;EncryptionMethod=plain"
|
||||
func TestSSVtoJson(t *testing.T) {
|
||||
ssv := "UID=iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=;PublicKey=IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=;ServerName=www.bing.com;TicketTimeHint=3600;NumConn=4;MaskBrowser=chrome;"
|
||||
json := ssvToJson(ssv)
|
||||
expected := []byte(`{"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=","PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=","ServerName":"www.bing.com","NumConn":4,"MaskBrowser":"chrome","ProxyMethod":"shadowsocks","EncryptionMethod":"plain"}`)
|
||||
|
||||
t.Run("byte equality", func(t *testing.T) {
|
||||
assert.Equal(t, expected, json)
|
||||
})
|
||||
|
||||
t.Run("struct equality", func(t *testing.T) {
|
||||
tmpConfig, _ := ioutil.TempFile("", "ck_client_config")
|
||||
_, _ = tmpConfig.Write(expected)
|
||||
parsedFromSSV, err := ParseConfig(ssv)
|
||||
assert.NoError(t, err)
|
||||
parsedFromJson, err := ParseConfig(tmpConfig.Name())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, parsedFromJson, parsedFromSSV)
|
||||
})
|
||||
|
||||
t.Run("empty file", func(t *testing.T) {
|
||||
tmpConfig, _ := ioutil.TempFile("", "ck_client_config")
|
||||
_, err := ParseConfig(tmpConfig.Name())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
expected := []byte(`{"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=","PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=","ServerName":"www.bing.com","TicketTimeHint":3600,"NumConn":4,"MaskBrowser":"chrome"}`)
|
||||
if !bytes.Equal(expected, json) {
|
||||
t.Error(
|
||||
"For", "ssvToJson",
|
||||
"expecting", string(expected),
|
||||
"got", string(json),
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type Transport interface {
|
||||
Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error)
|
||||
net.Conn
|
||||
}
|
||||
|
||||
type TransportConfig struct {
|
||||
mode string
|
||||
|
||||
wsUrl string
|
||||
|
||||
browser browser
|
||||
}
|
||||
|
||||
func (t TransportConfig) CreateTransport() Transport {
|
||||
switch t.mode {
|
||||
case "cdn":
|
||||
return &WSOverTLS{
|
||||
wsUrl: t.wsUrl,
|
||||
}
|
||||
case "direct":
|
||||
return &DirectTLS{
|
||||
browser: t.browser,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/gorilla/websocket"
|
||||
utls "github.com/refraction-networking/utls"
|
||||
)
|
||||
|
||||
type WSOverTLS struct {
|
||||
*common.WebSocketConn
|
||||
wsUrl string
|
||||
}
|
||||
|
||||
func (ws *WSOverTLS) Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) {
|
||||
utlsConfig := &utls.Config{
|
||||
ServerName: authInfo.MockDomain,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
uconn := utls.UClient(rawConn, utlsConfig, utls.HelloChrome_Auto)
|
||||
err = uconn.BuildHandshakeState()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for i, extension := range uconn.Extensions {
|
||||
_, ok := extension.(*utls.ALPNExtension)
|
||||
if ok {
|
||||
uconn.Extensions = append(uconn.Extensions[:i], uconn.Extensions[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = uconn.Handshake()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(ws.wsUrl)
|
||||
if err != nil {
|
||||
return sessionKey, fmt.Errorf("failed to parse ws url: %v", err)
|
||||
}
|
||||
|
||||
payload, sharedSecret := makeAuthenticationPayload(authInfo)
|
||||
header := http.Header{}
|
||||
header.Add("hidden", base64.StdEncoding.EncodeToString(append(payload.randPubKey[:], payload.ciphertextWithTag[:]...)))
|
||||
c, _, err := websocket.NewClient(uconn, u, header, 16480, 16480)
|
||||
if err != nil {
|
||||
return sessionKey, fmt.Errorf("failed to handshake: %v", err)
|
||||
}
|
||||
|
||||
ws.WebSocketConn = &common.WebSocketConn{Conn: c}
|
||||
|
||||
buf := make([]byte, 128)
|
||||
n, err := ws.Read(buf)
|
||||
if err != nil {
|
||||
return sessionKey, fmt.Errorf("failed to read reply: %v", err)
|
||||
}
|
||||
|
||||
if n != 60 {
|
||||
return sessionKey, errors.New("reply must be 60 bytes")
|
||||
}
|
||||
|
||||
reply := buf[:60]
|
||||
sessionKeySlice, err := common.AESGCMDecrypt(reply[:12], sharedSecret[:], reply[12:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
copy(sessionKey[:], sessionKeySlice)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (ws *WSOverTLS) Close() error {
|
||||
if ws.WebSocketConn != nil {
|
||||
return ws.WebSocketConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
/*
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
/*
|
||||
Forked from https://golang.org/src/io/io.go
|
||||
*/
|
||||
package common
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
func Copy(dst net.Conn, src net.Conn) (written int64, err error) {
|
||||
defer func() { src.Close(); dst.Close() }()
|
||||
|
||||
// If the reader has a WriteTo method, use it to do the copy.
|
||||
// Avoids an allocation and a copy.
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
|
||||
size := 32 * 1024
|
||||
buf := make([]byte, size)
|
||||
for {
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
|
|
@ -1,97 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aesgcm.NonceSize() {
|
||||
// check here so it doesn't panic
|
||||
return nil, errors.New("incorrect nonce size")
|
||||
}
|
||||
|
||||
return aesgcm.Seal(nil, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
func AESGCMDecrypt(nonce []byte, key []byte, ciphertext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aesgcm.NonceSize() {
|
||||
// check here so it doesn't panic
|
||||
return nil, errors.New("incorrect nonce size")
|
||||
}
|
||||
plain, err := aesgcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
func CryptoRandRead(buf []byte) {
|
||||
RandRead(rand.Reader, buf)
|
||||
}
|
||||
|
||||
func backoff(f func() error) {
|
||||
err := f()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
waitDur := [10]time.Duration{5 * time.Millisecond, 10 * time.Millisecond, 30 * time.Millisecond, 50 * time.Millisecond,
|
||||
100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second,
|
||||
3 * time.Second, 5 * time.Second}
|
||||
for i := 0; i < 10; i++ {
|
||||
log.Errorf("Failed to get random: %v. Retrying...", err)
|
||||
err = f()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(waitDur[i])
|
||||
}
|
||||
log.Fatal("Cannot get random after 10 retries")
|
||||
}
|
||||
|
||||
func RandRead(randSource io.Reader, buf []byte) {
|
||||
backoff(func() error {
|
||||
_, err := randSource.Read(buf)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func RandItem[T any](list []T) T {
|
||||
return list[RandInt(len(list))]
|
||||
}
|
||||
|
||||
func RandInt(n int) int {
|
||||
s := new(int)
|
||||
backoff(func() error {
|
||||
size, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = int(size.Int64())
|
||||
return nil
|
||||
})
|
||||
return *s
|
||||
}
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const gcmTagSize = 16
|
||||
|
||||
func TestAESGCM(t *testing.T) {
|
||||
// test vectors from https://luca-giuzzi.unibs.it/corsi/Support/papers-cryptography/gcm-spec.pdf
|
||||
t.Run("correct 128", func(t *testing.T) {
|
||||
key, _ := hex.DecodeString("00000000000000000000000000000000")
|
||||
plaintext, _ := hex.DecodeString("")
|
||||
nonce, _ := hex.DecodeString("000000000000000000000000")
|
||||
ciphertext, _ := hex.DecodeString("")
|
||||
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
|
||||
|
||||
encryptedWithTag, err := AESGCMEncrypt(nonce, key, plaintext)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ciphertext, encryptedWithTag[:len(plaintext)])
|
||||
assert.Equal(t, tag, encryptedWithTag[len(plaintext):len(plaintext)+gcmTagSize])
|
||||
|
||||
decrypted, err := AESGCMDecrypt(nonce, key, encryptedWithTag)
|
||||
assert.NoError(t, err)
|
||||
// slight inconvenience here that assert.Equal does not consider a nil slice and an empty slice to be
|
||||
// equal. decrypted should be []byte(nil) but plaintext is []byte{}
|
||||
assert.True(t, bytes.Equal(plaintext, decrypted))
|
||||
})
|
||||
t.Run("bad key size", func(t *testing.T) {
|
||||
key, _ := hex.DecodeString("0000000000000000000000000000")
|
||||
plaintext, _ := hex.DecodeString("")
|
||||
nonce, _ := hex.DecodeString("000000000000000000000000")
|
||||
ciphertext, _ := hex.DecodeString("")
|
||||
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
|
||||
|
||||
_, err := AESGCMEncrypt(nonce, key, plaintext)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("bad nonce size", func(t *testing.T) {
|
||||
key, _ := hex.DecodeString("00000000000000000000000000000000")
|
||||
plaintext, _ := hex.DecodeString("")
|
||||
nonce, _ := hex.DecodeString("00000000000000000000")
|
||||
ciphertext, _ := hex.DecodeString("")
|
||||
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
|
||||
|
||||
_, err := AESGCMEncrypt(nonce, key, plaintext)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("bad tag", func(t *testing.T) {
|
||||
key, _ := hex.DecodeString("00000000000000000000000000000000")
|
||||
nonce, _ := hex.DecodeString("00000000000000000000")
|
||||
ciphertext, _ := hex.DecodeString("")
|
||||
tag, _ := hex.DecodeString("fffffccefa7e3061367f1d57a4e745ff")
|
||||
|
||||
_, err := AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
fails int
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (f *failingReader) Read(p []byte) (n int, err error) {
|
||||
if f.fails > 0 {
|
||||
f.fails -= 1
|
||||
return 0, errors.New("no data for you yet")
|
||||
} else {
|
||||
return f.reader.Read(p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandRead(t *testing.T) {
|
||||
failer := &failingReader{
|
||||
fails: 3,
|
||||
reader: rand.New(rand.NewSource(0)),
|
||||
}
|
||||
readBuf := make([]byte, 10)
|
||||
RandRead(failer, readBuf)
|
||||
assert.NotEqual(t, [10]byte{}, readBuf)
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
package common
|
||||
|
||||
import "net"
|
||||
|
||||
type Dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
}
|
||||
|
|
@ -1,112 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
VersionTLS11 = 0x0301
|
||||
VersionTLS13 = 0x0303
|
||||
|
||||
recordLayerLength = 5
|
||||
|
||||
Handshake = 22
|
||||
ApplicationData = 23
|
||||
|
||||
initialWriteBufSize = 14336
|
||||
)
|
||||
|
||||
func AddRecordLayer(input []byte, typ byte, ver uint16) []byte {
|
||||
msgLen := len(input)
|
||||
retLen := msgLen + recordLayerLength
|
||||
var ret []byte
|
||||
ret = make([]byte, retLen)
|
||||
copy(ret[recordLayerLength:], input)
|
||||
ret[0] = typ
|
||||
ret[1] = byte(ver >> 8)
|
||||
ret[2] = byte(ver)
|
||||
ret[3] = byte(msgLen >> 8)
|
||||
ret[4] = byte(msgLen)
|
||||
return ret
|
||||
}
|
||||
|
||||
type TLSConn struct {
|
||||
net.Conn
|
||||
writeBufPool sync.Pool
|
||||
}
|
||||
|
||||
func NewTLSConn(conn net.Conn) *TLSConn {
|
||||
return &TLSConn{
|
||||
Conn: conn,
|
||||
writeBufPool: sync.Pool{New: func() interface{} {
|
||||
b := make([]byte, 0, initialWriteBufSize)
|
||||
b = append(b, ApplicationData, byte(VersionTLS13>>8), byte(VersionTLS13&0xFF))
|
||||
return &b
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (tls *TLSConn) LocalAddr() net.Addr {
|
||||
return tls.Conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (tls *TLSConn) RemoteAddr() net.Addr {
|
||||
return tls.Conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (tls *TLSConn) SetDeadline(t time.Time) error {
|
||||
return tls.Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (tls *TLSConn) SetReadDeadline(t time.Time) error {
|
||||
return tls.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (tls *TLSConn) SetWriteDeadline(t time.Time) error {
|
||||
return tls.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
|
||||
// TCP is a stream. Multiple TLS messages can arrive at the same time,
|
||||
// a single message can also be segmented due to MTU of the IP layer.
|
||||
// This function guareentees a single TLS message to be read and everything
|
||||
// else is left in the buffer.
|
||||
if len(buffer) < recordLayerLength {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
_, err = io.ReadFull(tls.Conn, buffer[:recordLayerLength])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dataLength := int(binary.BigEndian.Uint16(buffer[3:5]))
|
||||
if dataLength > len(buffer) {
|
||||
err = io.ErrShortBuffer
|
||||
return
|
||||
}
|
||||
// we overwrite the record layer here
|
||||
return io.ReadFull(tls.Conn, buffer[:dataLength])
|
||||
}
|
||||
|
||||
func (tls *TLSConn) Write(in []byte) (n int, err error) {
|
||||
msgLen := len(in)
|
||||
if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2
|
||||
return 0, errors.New("message is too long")
|
||||
}
|
||||
writeBuf := tls.writeBufPool.Get().(*[]byte)
|
||||
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
|
||||
*writeBuf = append(*writeBuf, in...)
|
||||
n, err = tls.Conn.Write(*writeBuf)
|
||||
*writeBuf = (*writeBuf)[:3]
|
||||
tls.writeBufPool.Put(writeBuf)
|
||||
return n - recordLayerLength, err
|
||||
}
|
||||
|
||||
func (tls *TLSConn) Close() error {
|
||||
return tls.Conn.Close()
|
||||
}
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkTLSConn_Write(b *testing.B) {
|
||||
const bufSize = 16 * 1024
|
||||
addrCh := make(chan string, 1)
|
||||
go func() {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
addrCh <- listener.Addr().String()
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
readBuf := make([]byte, bufSize*2)
|
||||
for {
|
||||
_, err = conn.Read(readBuf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
data := make([]byte, bufSize)
|
||||
discardConn, _ := net.Dial("tcp", <-addrCh)
|
||||
tlsConn := NewTLSConn(discardConn)
|
||||
defer tlsConn.Close()
|
||||
b.SetBytes(bufSize)
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
tlsConn.Write(data)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WebSocketConn implements io.ReadWriteCloser
|
||||
// it makes websocket.Conn binary-oriented
|
||||
type WebSocketConn struct {
|
||||
*websocket.Conn
|
||||
writeM sync.Mutex
|
||||
}
|
||||
|
||||
func (ws *WebSocketConn) Write(data []byte) (int, error) {
|
||||
ws.writeM.Lock()
|
||||
err := ws.WriteMessage(websocket.BinaryMessage, data)
|
||||
ws.writeM.Unlock()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return len(data), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WebSocketConn) Read(buf []byte) (n int, err error) {
|
||||
t, r, err := ws.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if t != websocket.BinaryMessage {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read until io.EOL for one full message
|
||||
for {
|
||||
var read int
|
||||
read, err = r.Read(buf[n:])
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
break
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// There may be data available to read but n == len(buf)-1, read==0 because buffer is full
|
||||
if read == 0 {
|
||||
err = errors.New("nothing more is read. message may be larger than buffer")
|
||||
break
|
||||
}
|
||||
}
|
||||
n += read
|
||||
}
|
||||
return
|
||||
}
|
||||
func (ws *WebSocketConn) Close() error {
|
||||
ws.writeM.Lock()
|
||||
defer ws.writeM.Unlock()
|
||||
return ws.Conn.Close()
|
||||
}
|
||||
|
||||
func (ws *WebSocketConn) SetDeadline(t time.Time) error {
|
||||
err := ws.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ws.SetWriteDeadline(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
var RealWorldState = WorldState{
|
||||
Rand: rand.Reader,
|
||||
Now: time.Now,
|
||||
}
|
||||
|
||||
type WorldState struct {
|
||||
Rand io.Reader
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
func WorldOfTime(t time.Time) WorldState {
|
||||
return WorldState{
|
||||
Rand: rand.Reader,
|
||||
Now: func() time.Time { return t },
|
||||
}
|
||||
}
|
||||
|
|
@ -68,11 +68,13 @@ func Unmarshal(data []byte) (crypto.PublicKey, bool) {
|
|||
return &pub, true
|
||||
}
|
||||
|
||||
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) {
|
||||
var priv, pub *[32]byte
|
||||
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) []byte {
|
||||
var priv, pub, secret *[32]byte
|
||||
|
||||
priv = privKey.(*[32]byte)
|
||||
pub = pubKey.(*[32]byte)
|
||||
secret = new([32]byte)
|
||||
|
||||
return curve25519.X25519(priv[:], pub[:])
|
||||
curve25519.ScalarMult(secret, priv, pub)
|
||||
return secret[:]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,105 +0,0 @@
|
|||
// This code is forked from https://github.com/wsddn/go-ecdh/blob/master/curve25519.go
|
||||
/*
|
||||
Copyright (c) 2014, tang0th
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of tang0th nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
package ecdh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCurve25519(t *testing.T) {
|
||||
testECDH(t)
|
||||
}
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
reader, writer := io.Pipe()
|
||||
_ = writer.Close()
|
||||
_, _, err := GenerateKey(reader)
|
||||
if err == nil {
|
||||
t.Error("GenerateKey should return error")
|
||||
}
|
||||
|
||||
_, ok := Unmarshal([]byte{1})
|
||||
if ok {
|
||||
t.Error("Unmarshal should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCurve25519(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
testECDH(b)
|
||||
}
|
||||
}
|
||||
|
||||
func testECDH(t testing.TB) {
|
||||
var privKey1, privKey2 crypto.PrivateKey
|
||||
var pubKey1, pubKey2 crypto.PublicKey
|
||||
var pubKey1Buf, pubKey2Buf []byte
|
||||
var err error
|
||||
var ok bool
|
||||
var secret1, secret2 []byte
|
||||
|
||||
privKey1, pubKey1, err = GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
privKey2, pubKey2, err = GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
pubKey1Buf = Marshal(pubKey1)
|
||||
pubKey2Buf = Marshal(pubKey2)
|
||||
|
||||
pubKey1, ok = Unmarshal(pubKey1Buf)
|
||||
if !ok {
|
||||
t.Fatalf("Unmarshal does not work")
|
||||
}
|
||||
|
||||
pubKey2, ok = Unmarshal(pubKey2Buf)
|
||||
if !ok {
|
||||
t.Fatalf("Unmarshal does not work")
|
||||
}
|
||||
|
||||
secret1, err = GenerateSharedSecret(privKey1, pubKey2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
secret2, err = GenerateSharedSecret(privKey2, pubKey1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(secret1, secret2) {
|
||||
t.Fatalf("The two shared keys: %d, %d do not match", secret1, secret2)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
|
||||
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// datagramBufferedPipe is the same as streamBufferedPipe with the exception that it's message-oriented,
|
||||
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
|
||||
// it won't get chopped up into individual bytes
|
||||
type datagramBufferedPipe struct {
|
||||
pLens []int
|
||||
buf *bytes.Buffer
|
||||
closed bool
|
||||
rwCond *sync.Cond
|
||||
wtTimeout time.Duration
|
||||
rDeadline time.Time
|
||||
|
||||
timeoutTimer *time.Timer
|
||||
}
|
||||
|
||||
func NewDatagramBufferedPipe() *datagramBufferedPipe {
|
||||
d := &datagramBufferedPipe{
|
||||
rwCond: sync.NewCond(&sync.Mutex{}),
|
||||
buf: new(bytes.Buffer),
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
|
||||
d.rwCond.L.Lock()
|
||||
defer d.rwCond.L.Unlock()
|
||||
for {
|
||||
if d.closed && len(d.pLens) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
hasRDeadline := !d.rDeadline.IsZero()
|
||||
if hasRDeadline {
|
||||
if time.Until(d.rDeadline) <= 0 {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
}
|
||||
|
||||
if len(d.pLens) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if hasRDeadline {
|
||||
d.broadcastAfter(time.Until(d.rDeadline))
|
||||
}
|
||||
d.rwCond.Wait()
|
||||
}
|
||||
dataLen := d.pLens[0]
|
||||
if len(target) < dataLen {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
d.pLens = d.pLens[1:]
|
||||
d.buf.Read(target[:dataLen])
|
||||
// err will always be nil because we have already verified that buf.Len() != 0
|
||||
d.rwCond.Broadcast()
|
||||
return dataLen, nil
|
||||
}
|
||||
|
||||
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
|
||||
d.rwCond.L.Lock()
|
||||
defer d.rwCond.L.Unlock()
|
||||
for {
|
||||
if d.closed {
|
||||
return true, io.ErrClosedPipe
|
||||
}
|
||||
if d.buf.Len() <= recvBufferSizeLimit {
|
||||
// if d.buf gets too large, write() will panic. We don't want this to happen
|
||||
break
|
||||
}
|
||||
d.rwCond.Wait()
|
||||
}
|
||||
|
||||
if f.Closing != closingNothing {
|
||||
d.closed = true
|
||||
d.rwCond.Broadcast()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
dataLen := len(f.Payload)
|
||||
d.pLens = append(d.pLens, dataLen)
|
||||
d.buf.Write(f.Payload)
|
||||
// err will always be nil
|
||||
d.rwCond.Broadcast()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (d *datagramBufferedPipe) Close() error {
|
||||
d.rwCond.L.Lock()
|
||||
defer d.rwCond.L.Unlock()
|
||||
|
||||
d.closed = true
|
||||
d.rwCond.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *datagramBufferedPipe) SetReadDeadline(t time.Time) {
|
||||
d.rwCond.L.Lock()
|
||||
defer d.rwCond.L.Unlock()
|
||||
|
||||
d.rDeadline = t
|
||||
d.rwCond.Broadcast()
|
||||
}
|
||||
|
||||
func (d *datagramBufferedPipe) broadcastAfter(t time.Duration) {
|
||||
if d.timeoutTimer != nil {
|
||||
d.timeoutTimer.Stop()
|
||||
}
|
||||
d.timeoutTimer = time.AfterFunc(t, d.rwCond.Broadcast)
|
||||
}
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDatagramBuffer_RW(t *testing.T) {
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
t.Run("simple write", func(t *testing.T) {
|
||||
pipe := NewDatagramBufferedPipe()
|
||||
_, err := pipe.Write(&Frame{Payload: b})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("simple read", func(t *testing.T) {
|
||||
pipe := NewDatagramBufferedPipe()
|
||||
_, _ = pipe.Write(&Frame{Payload: b})
|
||||
b2 := make([]byte, len(b))
|
||||
n, err := pipe.Read(b2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(b), n)
|
||||
assert.Equal(t, b, b2)
|
||||
assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading")
|
||||
})
|
||||
|
||||
t.Run("writing closing frame", func(t *testing.T) {
|
||||
pipe := NewDatagramBufferedPipe()
|
||||
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, toBeClosed, "should be to be closed")
|
||||
assert.True(t, pipe.closed, "pipe should be closed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatagramBuffer_BlockingRead(t *testing.T) {
|
||||
pipe := NewDatagramBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
go func() {
|
||||
time.Sleep(readBlockTime)
|
||||
pipe.Write(&Frame{Payload: b})
|
||||
}()
|
||||
b2 := make([]byte, len(b))
|
||||
n, err := pipe.Read(b2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
|
||||
assert.Equal(t, b, b2)
|
||||
}
|
||||
|
||||
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
|
||||
pipe := NewDatagramBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
pipe.Write(&Frame{Payload: b})
|
||||
b2 := make([]byte, len(b))
|
||||
pipe.Close()
|
||||
n, err := pipe.Read(b2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
|
||||
assert.Equal(t, b, b2)
|
||||
}
|
||||
|
|
@ -1,14 +1,8 @@
|
|||
package multiplex
|
||||
|
||||
const (
|
||||
closingNothing = iota
|
||||
closingStream
|
||||
closingSession
|
||||
)
|
||||
|
||||
type Frame struct {
|
||||
StreamID uint32
|
||||
Seq uint64
|
||||
Seq uint32
|
||||
Closing uint8
|
||||
Payload []byte
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
//"log"
|
||||
)
|
||||
|
||||
// The data is multiplexed through several TCP connections, therefore the
|
||||
// order of arrival is not guaranteed. A stream's first packet may be sent through
|
||||
// connection0 and its second packet may be sent through connection1. Although both
|
||||
// packets are transmitted reliably (as TCP is reliable), packet1 may arrive to the
|
||||
// remote side before packet0.
|
||||
//
|
||||
// However, shadowsocks' protocol does not provide sequence control. We must therefore
|
||||
// make sure packets arrive in order.
|
||||
//
|
||||
// Cloak packets will have a 32-bit sequence number on them, so we know in which order
|
||||
// they should be sent to shadowsocks. The code in this file provides buffering and sorting.
|
||||
//
|
||||
// Similar to TCP, the next seq number after 2^32-1 is 0. This is called wrap around.
|
||||
//
|
||||
// Note that in golang, integer overflow results in wrap around
|
||||
//
|
||||
// Stream.nextRecvSeq is the expected sequence number of the next packet
|
||||
// Stream.rev counts the amount of time the sequence number gets wrapped
|
||||
|
||||
type frameNode struct {
|
||||
trueSeq uint64
|
||||
frame *Frame
|
||||
}
|
||||
type sorterHeap []*frameNode
|
||||
|
||||
func (sh sorterHeap) Less(i, j int) bool {
|
||||
return sh[i].trueSeq < sh[j].trueSeq
|
||||
}
|
||||
func (sh sorterHeap) Len() int {
|
||||
return len(sh)
|
||||
}
|
||||
func (sh sorterHeap) Swap(i, j int) {
|
||||
sh[i], sh[j] = sh[j], sh[i]
|
||||
}
|
||||
|
||||
func (sh *sorterHeap) Push(x interface{}) {
|
||||
*sh = append(*sh, x.(*frameNode))
|
||||
}
|
||||
|
||||
func (sh *sorterHeap) Pop() interface{} {
|
||||
old := *sh
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*sh = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
func (s *Stream) writeNewFrame(f *Frame) {
|
||||
s.newFrameCh <- f
|
||||
}
|
||||
|
||||
// recvNewFrame is a forever running loop which receives frames unordered,
|
||||
// cache and order them and send them into sortedBufCh
|
||||
func (s *Stream) recvNewFrame() {
|
||||
for {
|
||||
var f *Frame
|
||||
select {
|
||||
case <-s.die:
|
||||
return
|
||||
case f = <-s.newFrameCh:
|
||||
}
|
||||
if f == nil { // This shouldn't happen
|
||||
//log.Println("nil frame")
|
||||
continue
|
||||
}
|
||||
|
||||
// when there's no ooo packages in heap and we receive the next package in order
|
||||
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
|
||||
if f.Closing == 1 {
|
||||
// empty data indicates closing signal
|
||||
s.sortedBufCh <- []byte{}
|
||||
return
|
||||
} else {
|
||||
s.sortedBufCh <- f.Payload
|
||||
s.nextRecvSeq += 1
|
||||
if s.nextRecvSeq == 0 { // getting wrapped
|
||||
s.rev += 1
|
||||
s.wrapMode = false
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
fs := &frameNode{
|
||||
trueSeq: 0,
|
||||
frame: f,
|
||||
}
|
||||
|
||||
if f.Seq < s.nextRecvSeq {
|
||||
// For the ease of demonstration, assume seq is uint8, i.e. it wraps around after 255
|
||||
// e.g. we are on rev=0 (wrap has not happened yet)
|
||||
// and we get the order of recv as 253 254 0 1
|
||||
// after 254, nextN should be 255, but 0 is received and 0 < 255
|
||||
// now 0 should have a trueSeq of 256
|
||||
if !s.wrapMode {
|
||||
// wrapMode is true when the latest seq is wrapped but nextN is not
|
||||
s.wrapMode = true
|
||||
}
|
||||
fs.trueSeq = uint64(1<<32)*uint64(s.rev+1) + uint64(f.Seq) + 1
|
||||
// +1 because wrapped 0 should have trueSeq of 256 instead of 255
|
||||
// when this bit was run on 1, the trueSeq of 1 would become 256
|
||||
} else {
|
||||
fs.trueSeq = uint64(1<<32)*uint64(s.rev) + uint64(f.Seq)
|
||||
// when this bit was run on 255, the trueSeq of 255 would be 255
|
||||
}
|
||||
|
||||
heap.Push(&s.sh, fs)
|
||||
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
||||
for len(s.sh) > 0 && s.sh[0].frame.Seq == s.nextRecvSeq {
|
||||
f = heap.Pop(&s.sh).(*frameNode).frame
|
||||
if f.Closing == 1 {
|
||||
// empty data indicates closing signal
|
||||
s.sortedBufCh <- []byte{}
|
||||
return
|
||||
} else {
|
||||
s.sortedBufCh <- f.Payload
|
||||
s.nextRecvSeq += 1
|
||||
if s.nextRecvSeq == 0 { // getting wrapped
|
||||
s.rev += 1
|
||||
s.wrapMode = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
//"log"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecvNewFrame(t *testing.T) {
|
||||
inOrder := []uint64{5, 6, 7, 8, 9, 10, 11}
|
||||
outOfOrder0 := []uint64{5, 7, 8, 6, 11, 10, 9}
|
||||
outOfOrder1 := []uint64{1, 96, 47, 2, 29, 18, 60, 8, 74, 22, 82, 58, 44, 51, 57, 71, 90, 94, 68, 83, 61, 91, 39, 97, 85, 63, 46, 73, 54, 84, 76, 98, 93, 79, 75, 50, 67, 37, 92, 99, 42, 77, 17, 16, 38, 3, 100, 24, 31, 7, 36, 40, 86, 64, 34, 45, 12, 5, 9, 27, 21, 26, 35, 6, 65, 69, 53, 4, 48, 28, 30, 56, 32, 11, 80, 66, 25, 41, 78, 13, 88, 62, 15, 70, 49, 43, 72, 23, 10, 55, 52, 95, 14, 59, 87, 33, 19, 20, 81, 89}
|
||||
outOfOrderWrap0 := []uint64{1<<32 - 5, 1<<32 + 3, 1 << 32, 1<<32 - 3, 1<<32 - 4, 1<<32 + 2, 1<<32 - 2, 1<<32 - 1, 1<<32 + 1}
|
||||
sets := [][]uint64{inOrder, outOfOrder0, outOfOrder1, outOfOrderWrap0}
|
||||
for _, set := range sets {
|
||||
stream := makeStream(1, &Session{})
|
||||
stream.nextRecvSeq = uint32(set[0])
|
||||
for _, n := range set {
|
||||
bu64 := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(bu64, n)
|
||||
frame := &Frame{
|
||||
Seq: uint32(n),
|
||||
Payload: bu64,
|
||||
}
|
||||
stream.writeNewFrame(frame)
|
||||
}
|
||||
|
||||
var testSorted []uint32
|
||||
for x := 0; x < len(set); x++ {
|
||||
p := <-stream.sortedBufCh
|
||||
//log.Print(p)
|
||||
testSorted = append(testSorted, uint32(binary.BigEndian.Uint64(p)))
|
||||
}
|
||||
sorted64 := make([]uint64, len(set))
|
||||
copy(sorted64, set)
|
||||
sort.Slice(sorted64, func(i, j int) bool { return sorted64[i] < sorted64[j] })
|
||||
sorted32 := make([]uint32, len(set))
|
||||
for i, _ := range sorted64 {
|
||||
sorted32[i] = uint32(sorted64[i])
|
||||
}
|
||||
|
||||
for i, _ := range sorted32 {
|
||||
if sorted32[i] != testSorted[i] {
|
||||
t.Error(
|
||||
"For", set,
|
||||
"expecting", sorted32,
|
||||
"got", testSorted,
|
||||
)
|
||||
}
|
||||
}
|
||||
close(stream.die)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/connutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func serveEcho(l net.Listener) {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
// TODO: pass the error back
|
||||
return
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
_, err := io.Copy(conn, conn)
|
||||
if err != nil {
|
||||
// TODO: pass the error back
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
type connPair struct {
|
||||
clientConn net.Conn
|
||||
serverConn net.Conn
|
||||
}
|
||||
|
||||
func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
|
||||
sessionKey := [32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
|
||||
sessionId := 1
|
||||
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
|
||||
clientConfig := SessionConfig{
|
||||
Obfuscator: obfuscator,
|
||||
Valve: nil,
|
||||
Unordered: false,
|
||||
}
|
||||
serverConfig := clientConfig
|
||||
|
||||
clientSession := MakeSession(uint32(sessionId), clientConfig)
|
||||
serverSession := MakeSession(uint32(sessionId), serverConfig)
|
||||
|
||||
paris := make([]*connPair, numConn)
|
||||
for i := 0; i < numConn; i++ {
|
||||
c, s := connutil.AsyncPipe()
|
||||
clientConn := common.NewTLSConn(c)
|
||||
serverConn := common.NewTLSConn(s)
|
||||
paris[i] = &connPair{
|
||||
clientConn: clientConn,
|
||||
serverConn: serverConn,
|
||||
}
|
||||
clientSession.AddConnection(clientConn)
|
||||
serverSession.AddConnection(serverConn)
|
||||
}
|
||||
return clientSession, serverSession, paris
|
||||
}
|
||||
|
||||
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, conn := range conns {
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
testData := make([]byte, msgLen)
|
||||
rand.Read(testData)
|
||||
|
||||
// we cannot call t.Fatalf in concurrent contexts
|
||||
n, err := conn.Write(testData)
|
||||
if n != msgLen {
|
||||
t.Errorf("written only %v, err %v", n, err)
|
||||
return
|
||||
}
|
||||
|
||||
recvBuf := make([]byte, msgLen)
|
||||
_, err = io.ReadFull(conn, recvBuf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read back: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(testData, recvBuf) {
|
||||
t.Errorf("echoed data not correct")
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestMultiplex(t *testing.T) {
|
||||
const numStreams = 2000 // -race option limits the number of goroutines to 8192
|
||||
const numConns = 4
|
||||
const msgLen = 16384
|
||||
|
||||
clientSession, serverSession, _ := makeSessionPair(numConns)
|
||||
go serveEcho(serverSession)
|
||||
|
||||
streams := make([]net.Conn, numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
stream, err := clientSession.OpenStream()
|
||||
assert.NoError(t, err)
|
||||
streams[i] = stream
|
||||
}
|
||||
|
||||
//test echo
|
||||
runEchoTest(t, streams, msgLen)
|
||||
|
||||
assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong")
|
||||
assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong")
|
||||
|
||||
// close one stream
|
||||
closing, streams := streams[0], streams[1:]
|
||||
err := closing.Close()
|
||||
assert.NoError(t, err, "couldn't close a stream")
|
||||
_, err = closing.Write([]byte{0})
|
||||
assert.Equal(t, ErrBrokenStream, err)
|
||||
_, err = closing.Read(make([]byte, 1))
|
||||
assert.Equal(t, ErrBrokenStream, err)
|
||||
}
|
||||
|
||||
func TestMux_StreamClosing(t *testing.T) {
|
||||
clientSession, serverSession, _ := makeSessionPair(1)
|
||||
go serveEcho(serverSession)
|
||||
|
||||
// read after closing stream
|
||||
testData := make([]byte, 128)
|
||||
recvBuf := make([]byte, 128)
|
||||
toBeClosed, _ := clientSession.OpenStream()
|
||||
_, err := toBeClosed.Write(testData) // should be echoed back
|
||||
assert.NoError(t, err, "couldn't write to a stream")
|
||||
|
||||
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
|
||||
assert.NoError(t, err, "can't read anything before stream closed")
|
||||
|
||||
_ = toBeClosed.Close()
|
||||
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
|
||||
assert.NoError(t, err, "can't read residual data on stream")
|
||||
assert.Equal(t, testData, recvBuf, "incorrect data read back")
|
||||
}
|
||||
|
|
@ -1,201 +1,72 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/salsa20"
|
||||
"io"
|
||||
)
|
||||
|
||||
const frameHeaderLength = 14
|
||||
const salsa20NonceSize = 8
|
||||
type Obfser func(*Frame) ([]byte, error)
|
||||
type Deobfser func([]byte) (*Frame, error)
|
||||
|
||||
// maxExtraLen equals the max length of padding + AEAD tag.
|
||||
// It is 255 bytes because the extra len field in frame header is only one byte.
|
||||
const maxExtraLen = 1<<8 - 1
|
||||
var u32 = binary.BigEndian.Uint32
|
||||
|
||||
// padFirstNFrames specifies the number of initial frames to pad,
|
||||
// to avoid TLS-in-TLS detection
|
||||
const padFirstNFrames = 5
|
||||
const headerLen = 12
|
||||
|
||||
const (
|
||||
EncryptionMethodPlain = iota
|
||||
EncryptionMethodAES256GCM
|
||||
EncryptionMethodChaha20Poly1305
|
||||
EncryptionMethodAES128GCM
|
||||
)
|
||||
|
||||
// Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames.
|
||||
type Obfuscator struct {
|
||||
payloadCipher cipher.AEAD
|
||||
|
||||
sessionKey [32]byte
|
||||
// For each frame, the three parts of the header is xored with three keys.
|
||||
// The keys are generated from the SID and the payload of the frame.
|
||||
func genXorKeys(key, nonce []byte) (i uint32, ii uint32, iii uint8) {
|
||||
h := sha1.New()
|
||||
hashed := h.Sum(append(key, nonce...))
|
||||
return u32(hashed[0:4]), u32(hashed[4:8]), hashed[8]
|
||||
}
|
||||
|
||||
// obfuscate adds multiplexing headers, encrypt and add TLS header
|
||||
func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) {
|
||||
// The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header
|
||||
// as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use
|
||||
// the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead())
|
||||
// as nonce for Salsa20 to encrypt the frame header. Both with sessionKey as keys.
|
||||
//
|
||||
// Several cryptographic guarantees we have made here: that payloadCipher, as an AEAD, is given a unique
|
||||
// iv/nonce each time, relative to its key; that the frame header encryptor Salsa20 is given a unique
|
||||
// nonce each time, relative to its key; and that the authenticity of frame header is checked.
|
||||
//
|
||||
// The payloadCipher is given a unique iv/nonce each time because it is derived from the frame header, which
|
||||
// contains the monotonically increasing stream id (uint32) and frame sequence (uint64). There will be a nonce
|
||||
// reuse after 2^64-1 frames sent (sent, not received because frames going different ways are sequenced
|
||||
// independently) by a stream, or after 2^32-1 streams created in a single session. We consider these number
|
||||
// to be large enough that they may never happen in reasonable time frames. Of course, different sessions
|
||||
// will produce the same combination of stream id and frame sequence, but they will have different session keys.
|
||||
//
|
||||
//
|
||||
// Because the frame header, before it being encrypted, is fed into the AEAD, it is also authenticated.
|
||||
// (rfc5116 s.2.1 "The nonce is authenticated internally to the algorithm").
|
||||
//
|
||||
// In case the user chooses to not encrypt the frame payload, payloadCipher will be nil. In this scenario,
|
||||
// we generate random bytes to be used as salsa20 nonce.
|
||||
payloadLen := len(f.Payload)
|
||||
if payloadLen == 0 {
|
||||
return 0, errors.New("payload cannot be empty")
|
||||
}
|
||||
tagLen := 0
|
||||
if o.payloadCipher != nil {
|
||||
tagLen = o.payloadCipher.Overhead()
|
||||
} else {
|
||||
tagLen = salsa20NonceSize
|
||||
}
|
||||
// Pad to avoid size side channel leak
|
||||
padLen := 0
|
||||
if f.Seq < padFirstNFrames {
|
||||
padLen = common.RandInt(maxExtraLen - tagLen + 1)
|
||||
}
|
||||
func MakeObfs(key []byte) Obfser {
|
||||
obfs := func(f *Frame) ([]byte, error) {
|
||||
obfsedHeader := make([]byte, headerLen)
|
||||
// header: [StreamID 4 bytes][Seq 4 bytes][Closing 1 byte][Nonce 3 bytes]
|
||||
io.ReadFull(rand.Reader, obfsedHeader[9:12])
|
||||
i, ii, iii := genXorKeys(key, obfsedHeader[9:12])
|
||||
binary.BigEndian.PutUint32(obfsedHeader[0:4], f.StreamID^i)
|
||||
binary.BigEndian.PutUint32(obfsedHeader[4:8], f.Seq^ii)
|
||||
obfsedHeader[8] = f.Closing ^ iii
|
||||
|
||||
usefulLen := frameHeaderLength + payloadLen + padLen + tagLen
|
||||
if len(buf) < usefulLen {
|
||||
return 0, errors.New("obfs buffer too small")
|
||||
// Composing final obfsed message
|
||||
// We don't use util.AddRecordLayer here to avoid unnecessary malloc
|
||||
obfsed := make([]byte, 5+headerLen+len(f.Payload))
|
||||
obfsed[0] = 0x17
|
||||
obfsed[1] = 0x03
|
||||
obfsed[2] = 0x03
|
||||
binary.BigEndian.PutUint16(obfsed[3:5], uint16(headerLen+len(f.Payload)))
|
||||
copy(obfsed[5:5+headerLen], obfsedHeader)
|
||||
copy(obfsed[5+headerLen:], f.Payload)
|
||||
// obfsed: [record layer 5 bytes][cipherheader 12 bytes][payload]
|
||||
return obfsed, nil
|
||||
}
|
||||
// we do as much in-place as possible to save allocation
|
||||
payload := buf[frameHeaderLength : frameHeaderLength+payloadLen+padLen]
|
||||
if payloadOffsetInBuf != frameHeaderLength {
|
||||
// if payload is not at the correct location in buffer
|
||||
copy(payload, f.Payload)
|
||||
}
|
||||
|
||||
header := buf[:frameHeaderLength]
|
||||
binary.BigEndian.PutUint32(header[0:4], f.StreamID)
|
||||
binary.BigEndian.PutUint64(header[4:12], f.Seq)
|
||||
header[12] = f.Closing
|
||||
header[13] = byte(padLen + tagLen)
|
||||
|
||||
// Random bytes for padding and nonce
|
||||
_, err := rand.Read(buf[frameHeaderLength+payloadLen : usefulLen])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to pad random: %w", err)
|
||||
}
|
||||
|
||||
if o.payloadCipher != nil {
|
||||
o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil)
|
||||
}
|
||||
|
||||
nonce := buf[usefulLen-salsa20NonceSize : usefulLen]
|
||||
salsa20.XORKeyStream(header, header, nonce, &o.sessionKey)
|
||||
|
||||
return usefulLen, nil
|
||||
return obfs
|
||||
}
|
||||
|
||||
// deobfuscate removes TLS header, decrypt and unmarshall frames
|
||||
func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error {
|
||||
if len(in) < frameHeaderLength+salsa20NonceSize {
|
||||
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize)
|
||||
}
|
||||
|
||||
header := in[:frameHeaderLength]
|
||||
pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead
|
||||
|
||||
nonce := in[len(in)-salsa20NonceSize:]
|
||||
salsa20.XORKeyStream(header, header, nonce, &o.sessionKey)
|
||||
|
||||
streamID := binary.BigEndian.Uint32(header[0:4])
|
||||
seq := binary.BigEndian.Uint64(header[4:12])
|
||||
closing := header[12]
|
||||
extraLen := header[13]
|
||||
|
||||
usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
|
||||
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
|
||||
return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
|
||||
}
|
||||
|
||||
var outputPayload []byte
|
||||
|
||||
if o.payloadCipher == nil {
|
||||
if extraLen == 0 {
|
||||
outputPayload = pldWithOverHead
|
||||
} else {
|
||||
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
||||
func MakeDeobfs(key []byte) Deobfser {
|
||||
deobfs := func(in []byte) (*Frame, error) {
|
||||
if len(in) < 5+headerLen {
|
||||
return nil, errors.New("Input cannot be shorter than 17 bytes")
|
||||
}
|
||||
} else {
|
||||
_, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
peeled := in[5:]
|
||||
i, ii, iii := genXorKeys(key, peeled[9:12])
|
||||
streamID := u32(peeled[0:4]) ^ i
|
||||
seq := u32(peeled[4:8]) ^ ii
|
||||
closing := peeled[8] ^ iii
|
||||
payload := make([]byte, len(peeled)-headerLen)
|
||||
copy(payload, peeled[headerLen:])
|
||||
ret := &Frame{
|
||||
StreamID: streamID,
|
||||
Seq: seq,
|
||||
Closing: closing,
|
||||
Payload: payload,
|
||||
}
|
||||
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
f.StreamID = streamID
|
||||
f.Seq = seq
|
||||
f.Closing = closing
|
||||
f.Payload = outputPayload
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (o Obfuscator, err error) {
|
||||
o = Obfuscator{
|
||||
sessionKey: sessionKey,
|
||||
}
|
||||
switch encryptionMethod {
|
||||
case EncryptionMethodPlain:
|
||||
o.payloadCipher = nil
|
||||
case EncryptionMethodAES256GCM:
|
||||
var c cipher.Block
|
||||
c, err = aes.NewCipher(sessionKey[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
o.payloadCipher, err = cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case EncryptionMethodAES128GCM:
|
||||
var c cipher.Block
|
||||
c, err = aes.NewCipher(sessionKey[:16])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
o.payloadCipher, err = cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case EncryptionMethodChaha20Poly1305:
|
||||
o.payloadCipher, err = chacha20poly1305.New(sessionKey[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
default:
|
||||
return o, fmt.Errorf("unknown encryption method valued %v", encryptionMethod)
|
||||
}
|
||||
|
||||
if o.payloadCipher != nil {
|
||||
if o.payloadCipher.NonceSize() > frameHeaderLength {
|
||||
return o, errors.New("payload AEAD's nonce size cannot be greater than size of frame header")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return deobfs
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,276 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
func TestGenerateObfs(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
run := func(o Obfuscator, t *testing.T) {
|
||||
obfsBuf := make([]byte, 512)
|
||||
_testFrame, _ := quick.Value(reflect.TypeOf(Frame{}), rand.New(rand.NewSource(42)))
|
||||
testFrame := _testFrame.Interface().(Frame)
|
||||
i, err := o.obfuscate(&testFrame, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
var resultFrame Frame
|
||||
|
||||
err = o.deobfuscate(&resultFrame, obfsBuf[:i])
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, testFrame, resultFrame)
|
||||
}
|
||||
|
||||
t.Run("plain", func(t *testing.T) {
|
||||
o, err := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||
assert.NoError(t, err)
|
||||
run(o, t)
|
||||
})
|
||||
t.Run("aes-256-gcm", func(t *testing.T) {
|
||||
o, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey)
|
||||
assert.NoError(t, err)
|
||||
run(o, t)
|
||||
})
|
||||
t.Run("aes-128-gcm", func(t *testing.T) {
|
||||
o, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey)
|
||||
assert.NoError(t, err)
|
||||
run(o, t)
|
||||
})
|
||||
t.Run("chacha20-poly1305", func(t *testing.T) {
|
||||
o, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
|
||||
assert.NoError(t, err)
|
||||
run(o, t)
|
||||
})
|
||||
t.Run("unknown encryption method", func(t *testing.T) {
|
||||
_, err := MakeObfuscator(0xff, sessionKey)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestObfuscate(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
const testPayloadLen = 1024
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
rand.Read(testPayload)
|
||||
f := Frame{
|
||||
StreamID: 0,
|
||||
Seq: 0,
|
||||
Closing: 0,
|
||||
Payload: testPayload,
|
||||
}
|
||||
|
||||
runTest := func(t *testing.T, o Obfuscator) {
|
||||
obfsBuf := make([]byte, testPayloadLen*2)
|
||||
n, err := o.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultFrame := Frame{}
|
||||
err = o.deobfuscate(&resultFrame, obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, f, resultFrame)
|
||||
}
|
||||
|
||||
t.Run("plain", func(t *testing.T) {
|
||||
o := Obfuscator{
|
||||
payloadCipher: nil,
|
||||
sessionKey: sessionKey,
|
||||
}
|
||||
runTest(t, o)
|
||||
})
|
||||
|
||||
t.Run("aes-128-gcm", func(t *testing.T) {
|
||||
c, err := aes.NewCipher(sessionKey[:16])
|
||||
assert.NoError(t, err)
|
||||
payloadCipher, err := cipher.NewGCM(c)
|
||||
assert.NoError(t, err)
|
||||
o := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: sessionKey,
|
||||
}
|
||||
runTest(t, o)
|
||||
})
|
||||
|
||||
t.Run("aes-256-gcm", func(t *testing.T) {
|
||||
c, err := aes.NewCipher(sessionKey[:])
|
||||
assert.NoError(t, err)
|
||||
payloadCipher, err := cipher.NewGCM(c)
|
||||
assert.NoError(t, err)
|
||||
o := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: sessionKey,
|
||||
}
|
||||
runTest(t, o)
|
||||
})
|
||||
|
||||
t.Run("chacha20-poly1305", func(t *testing.T) {
|
||||
payloadCipher, err := chacha20poly1305.New(sessionKey[:])
|
||||
assert.NoError(t, err)
|
||||
o := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: sessionKey,
|
||||
}
|
||||
runTest(t, o)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkObfs(b *testing.B) {
|
||||
testPayload := make([]byte, 1024)
|
||||
rand.Read(testPayload)
|
||||
testFrame := &Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
obfsBuf := make([]byte, len(testPayload)*2)
|
||||
|
||||
var key [32]byte
|
||||
rand.Read(key[:])
|
||||
b.Run("AES256GCM", func(b *testing.B) {
|
||||
c, _ := aes.NewCipher(key[:])
|
||||
payloadCipher, _ := cipher.NewGCM(c)
|
||||
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: key,
|
||||
}
|
||||
|
||||
b.SetBytes(int64(len(testFrame.Payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
}
|
||||
})
|
||||
b.Run("AES128GCM", func(b *testing.B) {
|
||||
c, _ := aes.NewCipher(key[:16])
|
||||
payloadCipher, _ := cipher.NewGCM(c)
|
||||
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: key,
|
||||
}
|
||||
b.SetBytes(int64(len(testFrame.Payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
}
|
||||
})
|
||||
b.Run("plain", func(b *testing.B) {
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: nil,
|
||||
sessionKey: key,
|
||||
}
|
||||
b.SetBytes(int64(len(testFrame.Payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
}
|
||||
})
|
||||
b.Run("chacha20Poly1305", func(b *testing.B) {
|
||||
payloadCipher, _ := chacha20poly1305.New(key[:])
|
||||
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: key,
|
||||
}
|
||||
b.SetBytes(int64(len(testFrame.Payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDeobfs(b *testing.B) {
|
||||
testPayload := make([]byte, 1024)
|
||||
rand.Read(testPayload)
|
||||
testFrame := &Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
obfsBuf := make([]byte, len(testPayload)*2)
|
||||
|
||||
var key [32]byte
|
||||
rand.Read(key[:])
|
||||
b.Run("AES256GCM", func(b *testing.B) {
|
||||
c, _ := aes.NewCipher(key[:])
|
||||
payloadCipher, _ := cipher.NewGCM(c)
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: key,
|
||||
}
|
||||
|
||||
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
|
||||
frame := new(Frame)
|
||||
b.SetBytes(int64(n))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||
}
|
||||
})
|
||||
b.Run("AES128GCM", func(b *testing.B) {
|
||||
c, _ := aes.NewCipher(key[:16])
|
||||
payloadCipher, _ := cipher.NewGCM(c)
|
||||
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: key,
|
||||
}
|
||||
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
|
||||
frame := new(Frame)
|
||||
b.ResetTimer()
|
||||
b.SetBytes(int64(n))
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||
}
|
||||
})
|
||||
b.Run("plain", func(b *testing.B) {
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: nil,
|
||||
sessionKey: key,
|
||||
}
|
||||
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
|
||||
frame := new(Frame)
|
||||
b.ResetTimer()
|
||||
b.SetBytes(int64(n))
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||
}
|
||||
})
|
||||
b.Run("chacha20Poly1305", func(b *testing.B) {
|
||||
payloadCipher, _ := chacha20poly1305.New(key[:])
|
||||
|
||||
obfuscator := Obfuscator{
|
||||
payloadCipher: payloadCipher,
|
||||
sessionKey: key,
|
||||
}
|
||||
|
||||
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
|
||||
|
||||
frame := new(Frame)
|
||||
b.ResetTimer()
|
||||
b.SetBytes(int64(n))
|
||||
for i := 0; i < b.N; i++ {
|
||||
obfuscator.deobfuscate(frame, obfsBuf[:n])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -7,60 +7,41 @@ import (
|
|||
)
|
||||
|
||||
// Valve needs to be universal, across all sessions that belong to a user
|
||||
type LimitedValve struct {
|
||||
// traffic directions from the server's perspective are referred
|
||||
// gabe please don't sue
|
||||
type Valve struct {
|
||||
// traffic directions from the server's perspective are refered
|
||||
// exclusively as rx and tx.
|
||||
// rx is from client to server, tx is from server to client
|
||||
// DO NOT use terms up or down as this is used in usermanager
|
||||
// for bandwidth limiting
|
||||
rxtb *ratelimit.Bucket
|
||||
txtb *ratelimit.Bucket
|
||||
rxtb atomic.Value // *ratelimit.Bucket
|
||||
txtb atomic.Value // *ratelimit.Bucket
|
||||
|
||||
rx *int64
|
||||
tx *int64
|
||||
rxCredit *int64
|
||||
txCredit *int64
|
||||
}
|
||||
|
||||
type UnlimitedValve struct{}
|
||||
|
||||
func MakeValve(rxRate, txRate int64) *LimitedValve {
|
||||
var rx, tx int64
|
||||
v := &LimitedValve{
|
||||
rxtb: ratelimit.NewBucketWithRate(float64(rxRate), rxRate),
|
||||
txtb: ratelimit.NewBucketWithRate(float64(txRate), txRate),
|
||||
rx: &rx,
|
||||
tx: &tx,
|
||||
func MakeValve(rxRate, txRate int64, rxCredit, txCredit *int64) *Valve {
|
||||
v := &Valve{
|
||||
rxCredit: rxCredit,
|
||||
txCredit: txCredit,
|
||||
}
|
||||
v.SetRxRate(rxRate)
|
||||
v.SetTxRate(txRate)
|
||||
return v
|
||||
}
|
||||
|
||||
var UNLIMITED_VALVE = &UnlimitedValve{}
|
||||
func (v *Valve) SetRxRate(rate int64) { v.rxtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) }
|
||||
func (v *Valve) SetTxRate(rate int64) { v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) }
|
||||
func (v *Valve) rxWait(n int) { v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
|
||||
func (v *Valve) txWait(n int) { v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
|
||||
func (v *Valve) SetRxCredit(n int64) { atomic.StoreInt64(v.rxCredit, n) }
|
||||
func (v *Valve) SetTxCredit(n int64) { atomic.StoreInt64(v.txCredit, n) }
|
||||
func (v *Valve) GetRxCredit() int64 { return atomic.LoadInt64(v.rxCredit) }
|
||||
func (v *Valve) GetTxCredit() int64 { return atomic.LoadInt64(v.txCredit) }
|
||||
|
||||
func (v *LimitedValve) rxWait(n int) { v.rxtb.Wait(int64(n)) }
|
||||
func (v *LimitedValve) txWait(n int) { v.txtb.Wait(int64(n)) }
|
||||
func (v *LimitedValve) AddRx(n int64) { atomic.AddInt64(v.rx, n) }
|
||||
func (v *LimitedValve) AddTx(n int64) { atomic.AddInt64(v.tx, n) }
|
||||
func (v *LimitedValve) GetRx() int64 { return atomic.LoadInt64(v.rx) }
|
||||
func (v *LimitedValve) GetTx() int64 { return atomic.LoadInt64(v.tx) }
|
||||
func (v *LimitedValve) Nullify() (int64, int64) {
|
||||
rx := atomic.SwapInt64(v.rx, 0)
|
||||
tx := atomic.SwapInt64(v.tx, 0)
|
||||
return rx, tx
|
||||
}
|
||||
// n can be negative
|
||||
func (v *Valve) AddRxCredit(n int64) int64 { return atomic.AddInt64(v.rxCredit, n) }
|
||||
|
||||
func (v *UnlimitedValve) rxWait(n int) {}
|
||||
func (v *UnlimitedValve) txWait(n int) {}
|
||||
func (v *UnlimitedValve) AddRx(n int64) {}
|
||||
func (v *UnlimitedValve) AddTx(n int64) {}
|
||||
func (v *UnlimitedValve) GetRx() int64 { return 0 }
|
||||
func (v *UnlimitedValve) GetTx() int64 { return 0 }
|
||||
func (v *UnlimitedValve) Nullify() (int64, int64) { return 0, 0 }
|
||||
|
||||
type Valve interface {
|
||||
rxWait(n int)
|
||||
txWait(n int)
|
||||
AddRx(n int64)
|
||||
AddTx(n int64)
|
||||
GetRx() int64
|
||||
GetTx() int64
|
||||
Nullify() (int64, int64)
|
||||
}
|
||||
// n can be negative
|
||||
func (v *Valve) AddTxCredit(n int64) int64 { return atomic.AddInt64(v.txCredit, n) }
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrTimeout = errors.New("deadline exceeded")
|
||||
|
||||
type recvBuffer interface {
|
||||
// Read calls' err must be nil | io.EOF | io.ErrShortBuffer
|
||||
// Read should NOT return error on a closed streamBuffer with a non-empty buffer.
|
||||
// Instead, it should behave as if it hasn't been closed. Closure is only relevant
|
||||
// when the buffer is empty.
|
||||
io.ReadCloser
|
||||
Write(*Frame) (toBeClosed bool, err error)
|
||||
SetReadDeadline(time time.Time)
|
||||
}
|
||||
|
||||
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
|
||||
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
|
||||
// a panic will happen.
|
||||
const recvBufferSizeLimit = 1<<31 - 1
|
||||
|
|
@ -2,351 +2,159 @@ package multiplex
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
//"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
acceptBacklog = 1024
|
||||
defaultInactivityTimeout = 30 * time.Second
|
||||
defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2
|
||||
acceptBacklog = 1024
|
||||
)
|
||||
|
||||
var ErrBrokenSession = errors.New("broken session")
|
||||
var errRepeatSessionClosing = errors.New("trying to close a closed session")
|
||||
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
||||
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
|
||||
|
||||
type SessionConfig struct {
|
||||
Obfuscator
|
||||
|
||||
// Valve is used to limit transmission rates, and record and limit usage
|
||||
Valve
|
||||
|
||||
Unordered bool
|
||||
|
||||
// A Singleplexing session always has just one stream
|
||||
Singleplex bool
|
||||
|
||||
// maximum size of an obfuscated frame, including headers and overhead
|
||||
MsgOnWireSizeLimit int
|
||||
|
||||
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself
|
||||
InactivityTimeout time.Duration
|
||||
}
|
||||
|
||||
// A Session represents a self-contained communication chain between local and remote. It manages its streams,
|
||||
// controls serialisation and encryption of data sent and received using the supplied Obfuscator, and send and receive
|
||||
// data through a manged connection pool filled with underlying connections added to it.
|
||||
type Session struct {
|
||||
id uint32
|
||||
|
||||
SessionConfig
|
||||
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
|
||||
obfs Obfser
|
||||
// Remove TLS header, decrypt and unmarshall multiplexing headers
|
||||
deobfs Deobfser
|
||||
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
|
||||
obfsedRead func(net.Conn, []byte) (int, error)
|
||||
|
||||
// atomic
|
||||
nextStreamID uint32
|
||||
|
||||
// atomic
|
||||
activeStreamCount uint32
|
||||
|
||||
streamsM sync.Mutex
|
||||
streams map[uint32]*Stream
|
||||
// For accepting new streams
|
||||
acceptCh chan *Stream
|
||||
|
||||
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
||||
recvFramePool sync.Pool
|
||||
|
||||
streamObfsBufPool sync.Pool
|
||||
|
||||
// Switchboard manages all connections to remote
|
||||
sb *switchboard
|
||||
|
||||
// Used for LocalAddr() and RemoteAddr() etc.
|
||||
addrs atomic.Value
|
||||
// For accepting new streams
|
||||
acceptCh chan *Stream
|
||||
|
||||
closed uint32
|
||||
|
||||
terminalMsgSetter sync.Once
|
||||
terminalMsg string
|
||||
|
||||
// the max size passed to Write calls before it splits it into multiple frames
|
||||
// i.e. the max size a piece of data can fit into a Frame.Payload
|
||||
maxStreamUnitWrite int
|
||||
// streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
|
||||
streamSendBufferSize int
|
||||
// connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
|
||||
// switchboard.deplex)
|
||||
connReceiveBufferSize int
|
||||
broken uint32
|
||||
die chan struct{}
|
||||
suicide sync.Once
|
||||
}
|
||||
|
||||
func MakeSession(id uint32, config SessionConfig) *Session {
|
||||
func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) *Session {
|
||||
sesh := &Session{
|
||||
id: id,
|
||||
SessionConfig: config,
|
||||
nextStreamID: 1,
|
||||
acceptCh: make(chan *Stream, acceptBacklog),
|
||||
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
|
||||
streams: map[uint32]*Stream{},
|
||||
id: id,
|
||||
obfs: obfs,
|
||||
deobfs: deobfs,
|
||||
obfsedRead: obfsedRead,
|
||||
nextStreamID: 1,
|
||||
streams: make(map[uint32]*Stream),
|
||||
acceptCh: make(chan *Stream, acceptBacklog),
|
||||
die: make(chan struct{}),
|
||||
}
|
||||
sesh.addrs.Store([]net.Addr{nil, nil})
|
||||
|
||||
if config.Valve == nil {
|
||||
sesh.Valve = UNLIMITED_VALVE
|
||||
}
|
||||
if config.MsgOnWireSizeLimit <= 0 {
|
||||
sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize
|
||||
}
|
||||
if config.InactivityTimeout == 0 {
|
||||
sesh.InactivityTimeout = defaultInactivityTimeout
|
||||
}
|
||||
|
||||
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - maxExtraLen
|
||||
sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit
|
||||
sesh.connReceiveBufferSize = 20480 // for backwards compatibility
|
||||
|
||||
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
|
||||
b := make([]byte, sesh.streamSendBufferSize)
|
||||
return &b
|
||||
}}
|
||||
|
||||
sesh.sb = makeSwitchboard(sesh)
|
||||
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
|
||||
sesh.sb = makeSwitchboard(sesh, valve)
|
||||
go sesh.timeoutAfter(30 * time.Second)
|
||||
return sesh
|
||||
}
|
||||
|
||||
func (sesh *Session) GetSessionKey() [32]byte {
|
||||
return sesh.sessionKey
|
||||
}
|
||||
|
||||
func (sesh *Session) streamCountIncr() uint32 {
|
||||
return atomic.AddUint32(&sesh.activeStreamCount, 1)
|
||||
}
|
||||
func (sesh *Session) streamCountDecr() uint32 {
|
||||
return atomic.AddUint32(&sesh.activeStreamCount, ^uint32(0))
|
||||
}
|
||||
func (sesh *Session) streamCount() uint32 {
|
||||
return atomic.LoadUint32(&sesh.activeStreamCount)
|
||||
}
|
||||
|
||||
// AddConnection is used to add an underlying connection to the connection pool
|
||||
func (sesh *Session) AddConnection(conn net.Conn) {
|
||||
sesh.sb.addConn(conn)
|
||||
addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()}
|
||||
sesh.addrs.Store(addrs)
|
||||
}
|
||||
|
||||
// OpenStream is similar to net.Dial. It opens up a new stream
|
||||
func (sesh *Session) OpenStream() (*Stream, error) {
|
||||
if sesh.IsClosed() {
|
||||
select {
|
||||
case <-sesh.die:
|
||||
return nil, ErrBrokenSession
|
||||
default:
|
||||
}
|
||||
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
|
||||
// Because atomic.AddUint32 returns the value after incrementation
|
||||
if sesh.Singleplex && id > 1 {
|
||||
// if there are more than one streams, which shouldn't happen if we are
|
||||
// singleplexing
|
||||
return nil, errNoMultiplex
|
||||
}
|
||||
stream := makeStream(sesh, id)
|
||||
stream := makeStream(id, sesh)
|
||||
sesh.streamsM.Lock()
|
||||
sesh.streams[id] = stream
|
||||
sesh.streamsM.Unlock()
|
||||
sesh.streamCountIncr()
|
||||
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
||||
//log.Printf("Opening stream %v\n", id)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Accept is similar to net.Listener's Accept(). It blocks and returns an incoming stream
|
||||
func (sesh *Session) Accept() (net.Conn, error) {
|
||||
if sesh.IsClosed() {
|
||||
func (sesh *Session) AcceptStream() (*Stream, error) {
|
||||
select {
|
||||
case <-sesh.die:
|
||||
return nil, ErrBrokenSession
|
||||
case stream := <-sesh.acceptCh:
|
||||
return stream, nil
|
||||
}
|
||||
stream := <-sesh.acceptCh
|
||||
if stream == nil {
|
||||
return nil, ErrBrokenSession
|
||||
}
|
||||
log.Tracef("stream %v of session %v accepted", stream.id, sesh.id)
|
||||
return stream, nil
|
||||
|
||||
}
|
||||
|
||||
func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
||||
}
|
||||
_ = s.recvBuf.Close() // recvBuf.Close should not return error
|
||||
|
||||
if active {
|
||||
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
|
||||
|
||||
// Notify remote that this stream is closed
|
||||
common.CryptoRandRead((*tmpBuf)[:1])
|
||||
padLen := int((*tmpBuf)[0]) + 1
|
||||
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
|
||||
common.CryptoRandRead(payload)
|
||||
|
||||
// must be holding s.wirtingM on entry
|
||||
s.writingFrame.Closing = closingStream
|
||||
s.writingFrame.Payload = payload
|
||||
|
||||
err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
|
||||
sesh.streamObfsBufPool.Put(tmpBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Tracef("stream %v actively closed.", s.id)
|
||||
} else {
|
||||
log.Tracef("stream %v passively closed", s.id)
|
||||
}
|
||||
|
||||
// We set it as nil to signify that the stream id had existed before.
|
||||
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
|
||||
// if the frame it received was from a new stream or a dying stream whose frame arrived late
|
||||
func (sesh *Session) delStream(id uint32) {
|
||||
sesh.streamsM.Lock()
|
||||
sesh.streams[s.id] = nil
|
||||
delete(sesh.streams, id)
|
||||
if len(sesh.streams) == 0 {
|
||||
go sesh.timeoutAfter(30 * time.Second)
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
if sesh.streamCountDecr() == 0 {
|
||||
if sesh.Singleplex {
|
||||
return sesh.Close()
|
||||
} else {
|
||||
log.Debugf("session %v has no active stream left", sesh.id)
|
||||
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame
|
||||
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
|
||||
// stream and then writes to the stream buffer
|
||||
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||
frame := sesh.recvFramePool.Get().(*Frame)
|
||||
defer sesh.recvFramePool.Put(frame)
|
||||
|
||||
err := sesh.deobfuscate(frame, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
||||
}
|
||||
|
||||
if frame.Closing == closingSession {
|
||||
sesh.SetTerminalMsg("Received a closing notification frame")
|
||||
return sesh.passiveClose()
|
||||
}
|
||||
|
||||
// either fetch an existing stream or instantiate a new stream and put it in the dict, and return it
|
||||
func (sesh *Session) getStream(id uint32, closingFrame bool) *Stream {
|
||||
// it would have been neater to use defer Unlock(), however it gives
|
||||
// non-negligable overhead and this function is performance critical
|
||||
sesh.streamsM.Lock()
|
||||
if sesh.IsClosed() {
|
||||
stream := sesh.streams[id]
|
||||
if stream != nil {
|
||||
sesh.streamsM.Unlock()
|
||||
return ErrBrokenSession
|
||||
}
|
||||
existingStream, existing := sesh.streams[frame.StreamID]
|
||||
if existing {
|
||||
sesh.streamsM.Unlock()
|
||||
if existingStream == nil {
|
||||
// this is when the stream existed before but has since been closed. We do nothing
|
||||
return stream
|
||||
} else {
|
||||
if closingFrame {
|
||||
// If the stream has been closed and the current frame is a closing frame,
|
||||
// we return nil
|
||||
sesh.streamsM.Unlock()
|
||||
return nil
|
||||
}
|
||||
return existingStream.recvFrame(frame)
|
||||
} else {
|
||||
newStream := makeStream(sesh, frame.StreamID)
|
||||
sesh.streams[frame.StreamID] = newStream
|
||||
sesh.acceptCh <- newStream
|
||||
sesh.streamsM.Unlock()
|
||||
// new stream
|
||||
sesh.streamCountIncr()
|
||||
return newStream.recvFrame(frame)
|
||||
}
|
||||
}
|
||||
|
||||
func (sesh *Session) SetTerminalMsg(msg string) {
|
||||
log.Debug("terminal message set to " + msg)
|
||||
sesh.terminalMsgSetter.Do(func() {
|
||||
sesh.terminalMsg = msg
|
||||
})
|
||||
}
|
||||
|
||||
func (sesh *Session) TerminalMsg() string {
|
||||
return sesh.terminalMsg
|
||||
}
|
||||
|
||||
func (sesh *Session) closeSession() error {
|
||||
if !atomic.CompareAndSwapUint32(&sesh.closed, 0, 1) {
|
||||
log.Debugf("session %v has already been closed", sesh.id)
|
||||
return errRepeatSessionClosing
|
||||
}
|
||||
|
||||
sesh.streamsM.Lock()
|
||||
close(sesh.acceptCh)
|
||||
for id, stream := range sesh.streams {
|
||||
if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) {
|
||||
_ = stream.recvBuf.Close() // will not block
|
||||
delete(sesh.streams, id)
|
||||
sesh.streamCountDecr()
|
||||
} else {
|
||||
stream = makeStream(id, sesh)
|
||||
sesh.streams[id] = stream
|
||||
sesh.acceptCh <- stream
|
||||
//log.Printf("Adding stream %v\n", id)
|
||||
sesh.streamsM.Unlock()
|
||||
return stream
|
||||
}
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sesh *Session) passiveClose() error {
|
||||
log.Debugf("attempting to passively close session %v", sesh.id)
|
||||
err := sesh.closeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sesh.sb.closeAll()
|
||||
log.Debugf("session %v closed gracefully", sesh.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sesh *Session) Close() error {
|
||||
log.Debugf("attempting to actively close session %v", sesh.id)
|
||||
err := sesh.closeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
// Because closing a closed channel causes panic
|
||||
sesh.suicide.Do(func() { close(sesh.die) })
|
||||
atomic.StoreUint32(&sesh.broken, 1)
|
||||
sesh.streamsM.Lock()
|
||||
for id, stream := range sesh.streams {
|
||||
// If we call stream.Close() here, streamsM will result in a deadlock
|
||||
// because stream.Close calls sesh.delStream, which locks the mutex.
|
||||
// so we need to implement a method of stream that closes the stream without calling
|
||||
// sesh.delStream
|
||||
go stream.closeNoDelMap()
|
||||
delete(sesh.streams, id)
|
||||
}
|
||||
// we send a notice frame telling remote to close the session
|
||||
sesh.streamsM.Unlock()
|
||||
|
||||
buf := sesh.streamObfsBufPool.Get().(*[]byte)
|
||||
common.CryptoRandRead((*buf)[:1])
|
||||
padLen := int((*buf)[0]) + 1
|
||||
payload := (*buf)[frameHeaderLength : padLen+frameHeaderLength]
|
||||
common.CryptoRandRead(payload)
|
||||
|
||||
f := &Frame{
|
||||
StreamID: 0xffffffff,
|
||||
Seq: 0,
|
||||
Closing: closingSession,
|
||||
Payload: payload,
|
||||
}
|
||||
i, err := sesh.obfuscate(f, *buf, frameHeaderLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sesh.sb.send((*buf)[:i], new(net.Conn))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sesh.sb.closeAll()
|
||||
log.Debugf("session %v closed gracefully", sesh.id)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (sesh *Session) IsClosed() bool {
|
||||
return atomic.LoadUint32(&sesh.closed) == 1
|
||||
func (sesh *Session) IsBroken() bool {
|
||||
return atomic.LoadUint32(&sesh.broken) == 1
|
||||
}
|
||||
|
||||
func (sesh *Session) checkTimeout() {
|
||||
if sesh.streamCount() == 0 && !sesh.IsClosed() {
|
||||
sesh.SetTerminalMsg("timeout")
|
||||
func (sesh *Session) timeoutAfter(to time.Duration) {
|
||||
time.Sleep(to)
|
||||
sesh.streamsM.Lock()
|
||||
if len(sesh.streams) == 0 && !sesh.IsBroken() {
|
||||
sesh.streamsM.Unlock()
|
||||
sesh.Close()
|
||||
} else {
|
||||
sesh.streamsM.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (sesh *Session) Addr() net.Addr { return sesh.addrs.Load().([]net.Addr)[0] }
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
//go:build gofuzz
|
||||
// +build gofuzz
|
||||
|
||||
package multiplex
|
||||
|
||||
func setupSesh_fuzz(unordered bool) *Session {
|
||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, [32]byte{})
|
||||
|
||||
seshConfig := SessionConfig{
|
||||
Obfuscator: obfuscator,
|
||||
Valve: nil,
|
||||
Unordered: unordered,
|
||||
}
|
||||
return MakeSession(0, seshConfig)
|
||||
}
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
sesh := setupSesh_fuzz(false)
|
||||
err := sesh.recvDataFromRemote(data)
|
||||
if err == nil {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
@ -1,640 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/connutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var seshConfigs = map[string]SessionConfig{
|
||||
"ordered": {},
|
||||
"unordered": {Unordered: true},
|
||||
}
|
||||
var encryptionMethods = map[string]byte{
|
||||
"plain": EncryptionMethodPlain,
|
||||
"aes-256-gcm": EncryptionMethodAES256GCM,
|
||||
"aes-128-gcm": EncryptionMethodAES128GCM,
|
||||
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
||||
}
|
||||
|
||||
const testPayloadLen = 1024
|
||||
const obfsBufLen = testPayloadLen * 2
|
||||
|
||||
func TestRecvDataFromRemote(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
seshConfig := seshConfig
|
||||
t.Run(seshType, func(t *testing.T) {
|
||||
var err error
|
||||
seshConfig.Obfuscator, err = MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make obfuscator: %v", err)
|
||||
}
|
||||
t.Run("initial frame", func(t *testing.T) {
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
f := Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
make([]byte, testPayloadLen),
|
||||
}
|
||||
rand.Read(f.Payload)
|
||||
n, err := sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
stream, err := sesh.Accept()
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultPayload := make([]byte, testPayloadLen)
|
||||
_, err = stream.Read(resultPayload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, f.Payload, resultPayload)
|
||||
})
|
||||
|
||||
t.Run("two frames in order", func(t *testing.T) {
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
f := Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
make([]byte, testPayloadLen),
|
||||
}
|
||||
rand.Read(f.Payload)
|
||||
n, err := sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
stream, err := sesh.Accept()
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultPayload := make([]byte, testPayloadLen)
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, f.Payload, resultPayload)
|
||||
|
||||
f.Seq += 1
|
||||
rand.Read(f.Payload)
|
||||
n, err = sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, f.Payload, resultPayload)
|
||||
})
|
||||
|
||||
t.Run("two frames in order", func(t *testing.T) {
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
f := Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
make([]byte, testPayloadLen),
|
||||
}
|
||||
rand.Read(f.Payload)
|
||||
n, err := sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
stream, err := sesh.Accept()
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultPayload := make([]byte, testPayloadLen)
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, f.Payload, resultPayload)
|
||||
|
||||
f.Seq += 1
|
||||
rand.Read(f.Payload)
|
||||
n, err = sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, f.Payload, resultPayload)
|
||||
})
|
||||
|
||||
if seshType == "ordered" {
|
||||
t.Run("frames out of order", func(t *testing.T) {
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
f := Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
nil,
|
||||
}
|
||||
|
||||
// First frame
|
||||
seq0 := make([]byte, testPayloadLen)
|
||||
rand.Read(seq0)
|
||||
f.Seq = 0
|
||||
f.Payload = seq0
|
||||
n, err := sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Third frame
|
||||
seq2 := make([]byte, testPayloadLen)
|
||||
rand.Read(seq2)
|
||||
f.Seq = 2
|
||||
f.Payload = seq2
|
||||
n, err = sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Second frame
|
||||
seq1 := make([]byte, testPayloadLen)
|
||||
rand.Read(seq1)
|
||||
f.Seq = 1
|
||||
f.Payload = seq1
|
||||
n, err = sesh.obfuscate(&f, obfsBuf, 0)
|
||||
assert.NoError(t, err)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Expect things to receive in order
|
||||
stream, err := sesh.Accept()
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultPayload := make([]byte, testPayloadLen)
|
||||
|
||||
// First
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, seq0, resultPayload)
|
||||
|
||||
// Second
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, seq1, resultPayload)
|
||||
|
||||
// Third
|
||||
_, err = io.ReadFull(stream, resultPayload)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, seq2, resultPayload)
|
||||
})
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
rand.Read(testPayload)
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
seshConfig := seshConfigs["ordered"]
|
||||
seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
|
||||
f1 := &Frame{
|
||||
1,
|
||||
0,
|
||||
closingNothing,
|
||||
testPayload,
|
||||
}
|
||||
// create stream 1
|
||||
n, _ := sesh.obfuscate(f1, obfsBuf, 0)
|
||||
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||
}
|
||||
sesh.streamsM.Lock()
|
||||
_, ok := sesh.streams[f1.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if !ok {
|
||||
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1")
|
||||
}
|
||||
|
||||
// create stream 2
|
||||
f2 := &Frame{
|
||||
2,
|
||||
0,
|
||||
closingNothing,
|
||||
testPayload,
|
||||
}
|
||||
n, _ = sesh.obfuscate(f2, obfsBuf, 0)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||
}
|
||||
sesh.streamsM.Lock()
|
||||
s2M, ok := sesh.streams[f2.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if s2M == nil || !ok {
|
||||
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||
}
|
||||
if sesh.streamCount() != 2 {
|
||||
t.Error("stream count isn't 2")
|
||||
}
|
||||
|
||||
// close stream 1
|
||||
f1CloseStream := &Frame{
|
||||
1,
|
||||
1,
|
||||
closingStream,
|
||||
testPayload,
|
||||
}
|
||||
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
sesh.streamsM.Lock()
|
||||
s1M, _ := sesh.streams[f1.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if s1M != nil {
|
||||
t.Fatal("stream 1 still exist after receiving stream close")
|
||||
}
|
||||
s1, _ := sesh.Accept()
|
||||
if !s1.(*Stream).isClosed() {
|
||||
t.Fatal("stream 1 not marked as closed")
|
||||
}
|
||||
payloadBuf := make([]byte, testPayloadLen)
|
||||
_, err = s1.Read(payloadBuf)
|
||||
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Fatalf("failed to read from stream 1 after closing: %v", err)
|
||||
}
|
||||
s2, _ := sesh.Accept()
|
||||
if s2.(*Stream).isClosed() {
|
||||
t.Fatal("stream 2 shouldn't be closed")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1 after stream 1 closed")
|
||||
}
|
||||
|
||||
// close stream 1 again
|
||||
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
||||
}
|
||||
sesh.streamsM.Lock()
|
||||
s1M, _ = sesh.streams[f1.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if s1M != nil {
|
||||
t.Error("stream 1 exists after receiving stream close for the second time")
|
||||
}
|
||||
streamCount := sesh.streamCount()
|
||||
if streamCount != 1 {
|
||||
t.Errorf("stream count is %v after stream 1 closed twice, expected 1", streamCount)
|
||||
}
|
||||
|
||||
// close session
|
||||
fCloseSession := &Frame{
|
||||
StreamID: 0xffffffff,
|
||||
Seq: 0,
|
||||
Closing: closingSession,
|
||||
Payload: testPayload,
|
||||
}
|
||||
n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving session closing frame: %v", err)
|
||||
}
|
||||
if !sesh.IsClosed() {
|
||||
t.Error("session not closed after receiving signal")
|
||||
}
|
||||
if !s2.(*Stream).isClosed() {
|
||||
t.Error("stream 2 isn't closed after session closed")
|
||||
}
|
||||
if _, err := s2.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Error("failed to read from stream 2 after session closed")
|
||||
}
|
||||
if _, err := s2.Write(testPayload); err == nil {
|
||||
t.Error("can still write to stream 2 after session closed")
|
||||
}
|
||||
if sesh.streamCount() != 0 {
|
||||
t.Error("stream count isn't 0 after session closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||
// Tests for when the closing frame of a stream is received first before any data frame
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
rand.Read(testPayload)
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
seshConfig := seshConfigs["ordered"]
|
||||
seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
|
||||
// receive stream 1 closing first
|
||||
f1CloseStream := &Frame{
|
||||
1,
|
||||
1,
|
||||
closingStream,
|
||||
testPayload,
|
||||
}
|
||||
n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0)
|
||||
err := sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
||||
}
|
||||
sesh.streamsM.Lock()
|
||||
_, ok := sesh.streams[f1CloseStream.StreamID]
|
||||
sesh.streamsM.Unlock()
|
||||
if !ok {
|
||||
t.Fatal("stream 1 doesn't exist")
|
||||
}
|
||||
if sesh.streamCount() != 1 {
|
||||
t.Error("stream count isn't 1 after stream 1 received")
|
||||
}
|
||||
|
||||
// receive data frame of stream 1 after receiving the closing frame
|
||||
f1 := &Frame{
|
||||
1,
|
||||
0,
|
||||
closingNothing,
|
||||
testPayload,
|
||||
}
|
||||
n, _ = sesh.obfuscate(f1, obfsBuf, 0)
|
||||
err = sesh.recvDataFromRemote(obfsBuf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||
}
|
||||
s1, err := sesh.Accept()
|
||||
if err != nil {
|
||||
t.Fatal("failed to accept stream 1 after receiving it")
|
||||
}
|
||||
payloadBuf := make([]byte, testPayloadLen)
|
||||
if _, err := s1.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
|
||||
t.Error("failed to read from steam 1")
|
||||
}
|
||||
if !s1.(*Stream).isClosed() {
|
||||
t.Error("s1 isn't closed")
|
||||
}
|
||||
if sesh.streamCount() != 0 {
|
||||
t.Error("stream count isn't 0 after stream 1 closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelStreams(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
seshConfig := seshConfig
|
||||
t.Run(seshType, func(t *testing.T) {
|
||||
seshConfig.Obfuscator = obfuscator
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
|
||||
numStreams := acceptBacklog
|
||||
seqs := make([]*uint64, numStreams)
|
||||
for i := range seqs {
|
||||
seqs[i] = new(uint64)
|
||||
}
|
||||
randFrame := func() *Frame {
|
||||
id := rand.Intn(numStreams)
|
||||
return &Frame{
|
||||
uint32(id),
|
||||
atomic.AddUint64(seqs[id], 1) - 1,
|
||||
uint8(rand.Intn(2)),
|
||||
[]byte{1, 2, 3, 4},
|
||||
}
|
||||
}
|
||||
|
||||
const numOfTests = 5000
|
||||
tests := make([]struct {
|
||||
name string
|
||||
frame *Frame
|
||||
}, numOfTests)
|
||||
for i := range tests {
|
||||
tests[i].name = strconv.Itoa(i)
|
||||
tests[i].frame = randFrame()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, tc := range tests {
|
||||
wg.Add(1)
|
||||
go func(frame *Frame) {
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
n, _ := sesh.obfuscate(frame, obfsBuf, 0)
|
||||
obfsBuf = obfsBuf[0:n]
|
||||
|
||||
err := sesh.recvDataFromRemote(obfsBuf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wg.Done()
|
||||
}(tc.frame)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
sc := int(sesh.streamCount())
|
||||
var count int
|
||||
sesh.streamsM.Lock()
|
||||
for _, s := range sesh.streams {
|
||||
if s != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
if sc != count {
|
||||
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_SetReadDeadline(t *testing.T) {
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
seshConfig := seshConfig
|
||||
t.Run(seshType, func(t *testing.T) {
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
sesh.AddConnection(connutil.Discard())
|
||||
|
||||
t.Run("read after deadline set", func(t *testing.T) {
|
||||
stream, _ := sesh.OpenStream()
|
||||
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
_, err := stream.Read(make([]byte, 1))
|
||||
if err != ErrTimeout {
|
||||
t.Errorf("expecting error %v, got %v", ErrTimeout, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unblock when deadline passed", func(t *testing.T) {
|
||||
stream, _ := sesh.OpenStream()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, _ = stream.Read(make([]byte, 1))
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Read did not unblock after deadline has passed")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSession_timeoutAfter(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
seshConfig := seshConfig
|
||||
t.Run(seshType, func(t *testing.T) {
|
||||
seshConfig.Obfuscator = obfuscator
|
||||
seshConfig.InactivityTimeout = 100 * time.Millisecond
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return sesh.IsClosed()
|
||||
}, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecvDataFromRemote(b *testing.B) {
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
rand.Read(testPayload)
|
||||
f := Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic
|
||||
for name, ep := range encryptionMethods {
|
||||
ep := ep
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
b.Run(seshType, func(b *testing.B) {
|
||||
f := f
|
||||
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
|
||||
go func() {
|
||||
stream, _ := sesh.Accept()
|
||||
io.Copy(ioutil.Discard, stream)
|
||||
}()
|
||||
|
||||
binaryFrames := [maxIter][]byte{}
|
||||
for i := 0; i < maxIter; i++ {
|
||||
obfsBuf := make([]byte, obfsBufLen)
|
||||
n, _ := sesh.obfuscate(&f, obfsBuf, 0)
|
||||
binaryFrames[i] = obfsBuf[:n]
|
||||
f.Seq++
|
||||
}
|
||||
|
||||
b.SetBytes(int64(len(f.Payload)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sesh.recvDataFromRemote(binaryFrames[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMultiStreamWrite(b *testing.B) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
testPayload := make([]byte, testPayloadLen)
|
||||
|
||||
for name, ep := range encryptionMethods {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
b.Run(seshType, func(b *testing.B) {
|
||||
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
sesh.AddConnection(connutil.Discard())
|
||||
b.ResetTimer()
|
||||
b.SetBytes(testPayloadLen)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
stream, _ := sesh.OpenStream()
|
||||
for pb.Next() {
|
||||
stream.Write(testPayload)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLatency(b *testing.B) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
for name, ep := range encryptionMethods {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for seshType, seshConfig := range seshConfigs {
|
||||
b.Run(seshType, func(b *testing.B) {
|
||||
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
|
||||
clientSesh := MakeSession(0, seshConfig)
|
||||
serverSesh := MakeSession(0, seshConfig)
|
||||
|
||||
c, s := net.Pipe()
|
||||
clientSesh.AddConnection(c)
|
||||
serverSesh.AddConnection(s)
|
||||
|
||||
buf := make([]byte, 64)
|
||||
clientStream, _ := clientSesh.OpenStream()
|
||||
clientStream.Write(buf)
|
||||
serverStream, _ := serverSesh.Accept()
|
||||
io.ReadFull(serverStream, buf)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
clientStream.Write(buf)
|
||||
io.ReadFull(serverStream, buf)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -2,210 +2,156 @@ package multiplex
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
//"log"
|
||||
"math"
|
||||
prand "math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var ErrBrokenStream = errors.New("broken stream")
|
||||
|
||||
// Stream implements net.Conn. It represents an optionally-ordered, full-duplex, self-contained connection.
|
||||
// If the session it belongs to runs in ordered mode, it provides ordering guarantee regardless of the underlying
|
||||
// connection used.
|
||||
// If the underlying connections the session uses are reliable, Stream is reliable. If they are not, Stream does not
|
||||
// guarantee reliability.
|
||||
type Stream struct {
|
||||
id uint32
|
||||
|
||||
session *Session
|
||||
|
||||
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
|
||||
// been read by the consumer through Read or WriteTo.
|
||||
recvBuf recvBuffer
|
||||
// Explanations of the following 4 fields can be found in frameSorter.go
|
||||
nextRecvSeq uint32
|
||||
rev int
|
||||
sh sorterHeap
|
||||
wrapMode bool
|
||||
|
||||
writingM sync.Mutex
|
||||
writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom
|
||||
// New frames are received through newFrameCh by frameSorter
|
||||
newFrameCh chan *Frame
|
||||
// sortedBufCh are order-sorted data ready to be read raw
|
||||
sortedBufCh chan []byte
|
||||
|
||||
// atomic
|
||||
closed uint32
|
||||
nextSendSeq uint32
|
||||
|
||||
// When we want order guarantee (i.e. session.Unordered is false),
|
||||
// we assign each stream a fixed underlying connection.
|
||||
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
|
||||
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
|
||||
// so it won't have to use its priority queue to sort it.
|
||||
// This is not used in unordered connection mode
|
||||
assignedConn net.Conn
|
||||
writingM sync.RWMutex
|
||||
|
||||
readFromTimeout time.Duration
|
||||
// close(die) is used to notify different goroutines that this stream is closing
|
||||
die chan struct{}
|
||||
heliumMask sync.Once // my personal fav
|
||||
}
|
||||
|
||||
func makeStream(sesh *Session, id uint32) *Stream {
|
||||
func makeStream(id uint32, sesh *Session) *Stream {
|
||||
stream := &Stream{
|
||||
id: id,
|
||||
session: sesh,
|
||||
writingFrame: Frame{
|
||||
StreamID: id,
|
||||
Seq: 0,
|
||||
Closing: closingNothing,
|
||||
},
|
||||
id: id,
|
||||
session: sesh,
|
||||
die: make(chan struct{}),
|
||||
sh: []*frameNode{},
|
||||
newFrameCh: make(chan *Frame, 1024),
|
||||
sortedBufCh: make(chan []byte, 1024),
|
||||
}
|
||||
|
||||
if sesh.Unordered {
|
||||
stream.recvBuf = NewDatagramBufferedPipe()
|
||||
} else {
|
||||
stream.recvBuf = NewStreamBuffer()
|
||||
}
|
||||
|
||||
go stream.recvNewFrame()
|
||||
return stream
|
||||
}
|
||||
|
||||
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
||||
|
||||
// receive a readily deobfuscated Frame so its payload can later be Read
|
||||
func (s *Stream) recvFrame(frame *Frame) error {
|
||||
toBeClosed, err := s.recvBuf.Write(frame)
|
||||
if toBeClosed {
|
||||
err = s.passiveClose()
|
||||
if errors.Is(err, errRepeatStreamClosing) {
|
||||
log.Debug(err)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Read implements io.Read
|
||||
func (s *Stream) Read(buf []byte) (n int, err error) {
|
||||
//log.Tracef("attempting to read from stream %v", s.id)
|
||||
func (stream *Stream) Read(buf []byte) (n int, err error) {
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, err = s.recvBuf.Read(buf)
|
||||
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
||||
if err == io.EOF {
|
||||
return n, ErrBrokenStream
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
|
||||
cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf)
|
||||
s.writingFrame.Seq++
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn)
|
||||
if err != nil {
|
||||
if err == errBrokenSwitchboard {
|
||||
s.session.SetTerminalMsg(err.Error())
|
||||
s.session.passiveClose()
|
||||
select {
|
||||
case <-stream.die:
|
||||
return 0, ErrBrokenStream
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write implements io.Write
|
||||
func (s *Stream) Write(in []byte) (n int, err error) {
|
||||
s.writingM.Lock()
|
||||
defer s.writingM.Unlock()
|
||||
if s.isClosed() {
|
||||
select {
|
||||
case <-stream.die:
|
||||
return 0, ErrBrokenStream
|
||||
case data := <-stream.sortedBufCh:
|
||||
if len(data) == 0 {
|
||||
stream.passiveClose()
|
||||
return 0, ErrBrokenStream
|
||||
}
|
||||
if len(buf) < len(data) {
|
||||
return 0, errors.New("buf too small")
|
||||
}
|
||||
copy(buf, data)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
for n < len(in) {
|
||||
var framePayload []byte
|
||||
if len(in)-n <= s.session.maxStreamUnitWrite {
|
||||
// if we can fit remaining data of in into one frame
|
||||
framePayload = in[n:]
|
||||
} else {
|
||||
// if we have to split
|
||||
if s.session.Unordered {
|
||||
// but we are not allowed to
|
||||
err = io.ErrShortBuffer
|
||||
return
|
||||
}
|
||||
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
||||
}
|
||||
s.writingFrame.Payload = framePayload
|
||||
buf := s.session.streamObfsBufPool.Get().(*[]byte)
|
||||
err = s.obfuscateAndSend(*buf, 0)
|
||||
s.session.streamObfsBufPool.Put(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += len(framePayload)
|
||||
}
|
||||
|
||||
func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||
// RWMutex used here isn't really for RW.
|
||||
// we use it to exploit the fact that RLock doesn't create contention.
|
||||
// The use of RWMutex is so that the stream will not actively close
|
||||
// in the middle of the execution of Write. This may cause the closing frame
|
||||
// to be sent before the data frame and cause loss of packet.
|
||||
stream.writingM.RLock()
|
||||
select {
|
||||
case <-stream.die:
|
||||
stream.writingM.RUnlock()
|
||||
return 0, ErrBrokenStream
|
||||
default:
|
||||
}
|
||||
|
||||
f := &Frame{
|
||||
StreamID: stream.id,
|
||||
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
|
||||
Closing: 0,
|
||||
Payload: in,
|
||||
}
|
||||
|
||||
tlsRecord, err := stream.session.obfs(f)
|
||||
if err != nil {
|
||||
stream.writingM.RUnlock()
|
||||
return 0, err
|
||||
}
|
||||
n, err = stream.session.sb.send(tlsRecord)
|
||||
stream.writingM.RUnlock()
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
|
||||
// for readFromTimeout amount of time
|
||||
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
for {
|
||||
if s.readFromTimeout != 0 {
|
||||
if rder, ok := r.(net.Conn); !ok {
|
||||
log.Warn("ReadFrom timeout is set but reader doesn't implement SetReadDeadline")
|
||||
} else {
|
||||
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
|
||||
}
|
||||
}
|
||||
buf := s.session.streamObfsBufPool.Get().(*[]byte)
|
||||
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
|
||||
if er != nil {
|
||||
return n, er
|
||||
}
|
||||
|
||||
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
|
||||
// to check that here
|
||||
if s.isClosed() {
|
||||
return n, ErrBrokenStream
|
||||
}
|
||||
|
||||
s.writingM.Lock()
|
||||
s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
|
||||
err = s.obfuscateAndSend(*buf, frameHeaderLength)
|
||||
s.writingM.Unlock()
|
||||
s.session.streamObfsBufPool.Put(buf)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(read)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) passiveClose() error {
|
||||
return s.session.closeStream(s, false)
|
||||
// only close locally. Used when the stream close is notified by the remote
|
||||
func (stream *Stream) passiveClose() {
|
||||
stream.heliumMask.Do(func() { close(stream.die) })
|
||||
stream.session.delStream(stream.id)
|
||||
//log.Printf("%v passive closing\n", stream.id)
|
||||
}
|
||||
|
||||
// active close. Close locally and tell the remote that this stream is being closed
|
||||
func (s *Stream) Close() error {
|
||||
s.writingM.Lock()
|
||||
defer s.writingM.Unlock()
|
||||
func (stream *Stream) Close() error {
|
||||
|
||||
return s.session.closeStream(s, true)
|
||||
stream.writingM.Lock()
|
||||
select {
|
||||
case <-stream.die:
|
||||
stream.writingM.Unlock()
|
||||
return errors.New("Already Closed")
|
||||
default:
|
||||
}
|
||||
stream.heliumMask.Do(func() { close(stream.die) })
|
||||
|
||||
// Notify remote that this stream is closed
|
||||
prand.Seed(int64(stream.id))
|
||||
padLen := int(math.Floor(prand.Float64()*200 + 300))
|
||||
pad := make([]byte, padLen)
|
||||
prand.Read(pad)
|
||||
f := &Frame{
|
||||
StreamID: stream.id,
|
||||
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
|
||||
Closing: 1,
|
||||
Payload: pad,
|
||||
}
|
||||
tlsRecord, _ := stream.session.obfs(f)
|
||||
stream.session.sb.send(tlsRecord)
|
||||
|
||||
stream.session.delStream(stream.id)
|
||||
//log.Printf("%v actively closed\n", stream.id)
|
||||
stream.writingM.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
|
||||
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
|
||||
|
||||
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
|
||||
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
|
||||
|
||||
var errNotImplemented = errors.New("Not implemented")
|
||||
|
||||
// the following functions are purely for implementing net.Conn interface.
|
||||
// they are not used
|
||||
// TODO: implement the following
|
||||
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|
||||
// Same as passiveClose() but no call to session.delStream.
|
||||
// This is called in session.Close() to avoid mutex deadlock
|
||||
// We don't notify the remote because session.Close() is always
|
||||
// called when the session is passively closed
|
||||
func (stream *Stream) closeNoDelMap() {
|
||||
stream.heliumMask.Do(func() { close(stream.die) })
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
// The data is multiplexed through several TCP connections, therefore the
|
||||
// order of arrival is not guaranteed. A stream's first packet may be sent through
|
||||
// connection0 and its second packet may be sent through connection1. Although both
|
||||
// packets are transmitted reliably (as TCP is reliable), packet1 may arrive to the
|
||||
// remote side before packet0. Cloak have to therefore sequence the packets so that they
|
||||
// arrive in order as they were sent by the proxy software
|
||||
//
|
||||
// Cloak packets will have a 64-bit sequence number on them, so we know in which order
|
||||
// they should be sent to the proxy software. The code in this file provides buffering and sorting.
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type sorterHeap []*Frame
|
||||
|
||||
func (sh sorterHeap) Less(i, j int) bool {
|
||||
return sh[i].Seq < sh[j].Seq
|
||||
}
|
||||
func (sh sorterHeap) Len() int {
|
||||
return len(sh)
|
||||
}
|
||||
func (sh sorterHeap) Swap(i, j int) {
|
||||
sh[i], sh[j] = sh[j], sh[i]
|
||||
}
|
||||
|
||||
func (sh *sorterHeap) Push(x interface{}) {
|
||||
*sh = append(*sh, x.(*Frame))
|
||||
}
|
||||
|
||||
func (sh *sorterHeap) Pop() interface{} {
|
||||
old := *sh
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*sh = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
type streamBuffer struct {
|
||||
recvM sync.Mutex
|
||||
|
||||
nextRecvSeq uint64
|
||||
sh sorterHeap
|
||||
|
||||
buf *streamBufferedPipe
|
||||
}
|
||||
|
||||
// streamBuffer is a wrapper around streamBufferedPipe.
|
||||
// Its main function is to sort frames in order, and wait for frames to arrive
|
||||
// if they have arrived out-of-order. Then it writes the payload of frames into
|
||||
// a streamBufferedPipe.
|
||||
func NewStreamBuffer() *streamBuffer {
|
||||
sb := &streamBuffer{
|
||||
sh: []*Frame{},
|
||||
buf: NewStreamBufferedPipe(),
|
||||
}
|
||||
return sb
|
||||
}
|
||||
|
||||
func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) {
|
||||
sb.recvM.Lock()
|
||||
defer sb.recvM.Unlock()
|
||||
// when there'fs no ooo packages in heap and we receive the next package in order
|
||||
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
|
||||
if f.Closing != closingNothing {
|
||||
return true, nil
|
||||
} else {
|
||||
sb.buf.Write(f.Payload)
|
||||
sb.nextRecvSeq += 1
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if f.Seq < sb.nextRecvSeq {
|
||||
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
||||
}
|
||||
|
||||
saved := *f
|
||||
saved.Payload = make([]byte, len(f.Payload))
|
||||
copy(saved.Payload, f.Payload)
|
||||
heap.Push(&sb.sh, &saved)
|
||||
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
||||
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
|
||||
f = heap.Pop(&sb.sh).(*Frame)
|
||||
if f.Closing != closingNothing {
|
||||
return true, nil
|
||||
} else {
|
||||
sb.buf.Write(f.Payload)
|
||||
sb.nextRecvSeq += 1
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (sb *streamBuffer) Read(buf []byte) (int, error) {
|
||||
return sb.buf.Read(buf)
|
||||
}
|
||||
|
||||
func (sb *streamBuffer) Close() error {
|
||||
sb.recvM.Lock()
|
||||
defer sb.recvM.Unlock()
|
||||
|
||||
return sb.buf.Close()
|
||||
}
|
||||
|
||||
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
//"log"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecvNewFrame(t *testing.T) {
|
||||
inOrder := []uint64{5, 6, 7, 8, 9, 10, 11}
|
||||
outOfOrder0 := []uint64{5, 7, 8, 6, 11, 10, 9}
|
||||
outOfOrder1 := []uint64{1, 96, 47, 2, 29, 18, 60, 8, 74, 22, 82, 58, 44, 51, 57, 71, 90, 94, 68, 83, 61, 91, 39, 97, 85, 63, 46, 73, 54, 84, 76, 98, 93, 79, 75, 50, 67, 37, 92, 99, 42, 77, 17, 16, 38, 3, 100, 24, 31, 7, 36, 40, 86, 64, 34, 45, 12, 5, 9, 27, 21, 26, 35, 6, 65, 69, 53, 4, 48, 28, 30, 56, 32, 11, 80, 66, 25, 41, 78, 13, 88, 62, 15, 70, 49, 43, 72, 23, 10, 55, 52, 95, 14, 59, 87, 33, 19, 20, 81, 89}
|
||||
outOfOrder2 := []uint64{1<<32 - 5, 1<<32 + 3, 1 << 32, 1<<32 - 3, 1<<32 - 4, 1<<32 + 2, 1<<32 - 2, 1<<32 - 1, 1<<32 + 1}
|
||||
|
||||
test := func(set []uint64, ct *testing.T) {
|
||||
sb := NewStreamBuffer()
|
||||
sb.nextRecvSeq = set[0]
|
||||
for _, n := range set {
|
||||
bu64 := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(bu64, n)
|
||||
sb.Write(&Frame{
|
||||
Seq: n,
|
||||
Payload: bu64,
|
||||
})
|
||||
}
|
||||
|
||||
var sortedResult []uint64
|
||||
for x := 0; x < len(set); x++ {
|
||||
oct := make([]byte, 8)
|
||||
n, err := sb.Read(oct)
|
||||
if n != 8 || err != nil {
|
||||
ct.Error("failed to read from sorted Buf", n, err)
|
||||
return
|
||||
}
|
||||
//log.Print(p)
|
||||
sortedResult = append(sortedResult, binary.BigEndian.Uint64(oct))
|
||||
}
|
||||
targetSorted := make([]uint64, len(set))
|
||||
copy(targetSorted, set)
|
||||
sort.Slice(targetSorted, func(i, j int) bool { return targetSorted[i] < targetSorted[j] })
|
||||
|
||||
for i := range targetSorted {
|
||||
if sortedResult[i] != targetSorted[i] {
|
||||
goto fail
|
||||
}
|
||||
}
|
||||
sb.Close()
|
||||
return
|
||||
fail:
|
||||
ct.Error(
|
||||
"expecting", targetSorted,
|
||||
"got", sortedResult,
|
||||
)
|
||||
}
|
||||
|
||||
t.Run("in order", func(t *testing.T) {
|
||||
test(inOrder, t)
|
||||
})
|
||||
t.Run("out of order0", func(t *testing.T) {
|
||||
test(outOfOrder0, t)
|
||||
})
|
||||
t.Run("out of order1", func(t *testing.T) {
|
||||
test(outOfOrder1, t)
|
||||
})
|
||||
t.Run("out of order wrap", func(t *testing.T) {
|
||||
test(outOfOrder2, t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamBuffer_RecvThenClose(t *testing.T) {
|
||||
const testDataLen = 128
|
||||
sb := NewStreamBuffer()
|
||||
testData := make([]byte, testDataLen)
|
||||
testFrame := Frame{
|
||||
StreamID: 0,
|
||||
Seq: 0,
|
||||
Closing: 0,
|
||||
Payload: testData,
|
||||
}
|
||||
sb.Write(&testFrame)
|
||||
sb.Close()
|
||||
|
||||
readBuf := make([]byte, testDataLen)
|
||||
_, err := io.ReadFull(sb, readBuf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
|
||||
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The point of a streamBufferedPipe is that Read() will block until data is available
|
||||
type streamBufferedPipe struct {
|
||||
buf *bytes.Buffer
|
||||
|
||||
closed bool
|
||||
rwCond *sync.Cond
|
||||
rDeadline time.Time
|
||||
wtTimeout time.Duration
|
||||
|
||||
timeoutTimer *time.Timer
|
||||
}
|
||||
|
||||
func NewStreamBufferedPipe() *streamBufferedPipe {
|
||||
p := &streamBufferedPipe{
|
||||
rwCond: sync.NewCond(&sync.Mutex{}),
|
||||
buf: new(bytes.Buffer),
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *streamBufferedPipe) Read(target []byte) (int, error) {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
for {
|
||||
if p.closed && p.buf.Len() == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
hasRDeadline := !p.rDeadline.IsZero()
|
||||
if hasRDeadline {
|
||||
if time.Until(p.rDeadline) <= 0 {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
}
|
||||
if p.buf.Len() > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if hasRDeadline {
|
||||
p.broadcastAfter(time.Until(p.rDeadline))
|
||||
}
|
||||
p.rwCond.Wait()
|
||||
}
|
||||
n, err := p.buf.Read(target)
|
||||
// err will always be nil because we have already verified that buf.Len() != 0
|
||||
p.rwCond.Broadcast()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (p *streamBufferedPipe) Write(input []byte) (int, error) {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
for {
|
||||
if p.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if p.buf.Len() <= recvBufferSizeLimit {
|
||||
// if p.buf gets too large, write() will panic. We don't want this to happen
|
||||
break
|
||||
}
|
||||
p.rwCond.Wait()
|
||||
}
|
||||
n, err := p.buf.Write(input)
|
||||
// err will always be nil
|
||||
p.rwCond.Broadcast()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (p *streamBufferedPipe) Close() error {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
|
||||
p.closed = true
|
||||
p.rwCond.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *streamBufferedPipe) SetReadDeadline(t time.Time) {
|
||||
p.rwCond.L.Lock()
|
||||
defer p.rwCond.L.Unlock()
|
||||
|
||||
p.rDeadline = t
|
||||
p.rwCond.Broadcast()
|
||||
}
|
||||
|
||||
func (p *streamBufferedPipe) broadcastAfter(d time.Duration) {
|
||||
if p.timeoutTimer != nil {
|
||||
p.timeoutTimer.Stop()
|
||||
}
|
||||
p.timeoutTimer = time.AfterFunc(d, p.rwCond.Broadcast)
|
||||
}
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const readBlockTime = 500 * time.Millisecond
|
||||
|
||||
func TestPipeRW(t *testing.T) {
|
||||
pipe := NewStreamBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
n, err := pipe.Write(b)
|
||||
assert.NoError(t, err, "simple write")
|
||||
assert.Equal(t, len(b), n, "number of bytes written")
|
||||
|
||||
b2 := make([]byte, len(b))
|
||||
n, err = pipe.Read(b2)
|
||||
assert.NoError(t, err, "simple read")
|
||||
assert.Equal(t, len(b), n, "number of bytes read")
|
||||
|
||||
assert.Equal(t, b, b2)
|
||||
}
|
||||
|
||||
func TestReadBlock(t *testing.T) {
|
||||
pipe := NewStreamBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
go func() {
|
||||
time.Sleep(readBlockTime)
|
||||
pipe.Write(b)
|
||||
}()
|
||||
b2 := make([]byte, len(b))
|
||||
n, err := pipe.Read(b2)
|
||||
assert.NoError(t, err, "blocked read")
|
||||
assert.Equal(t, len(b), n, "number of bytes read after block")
|
||||
|
||||
assert.Equal(t, b, b2)
|
||||
}
|
||||
|
||||
func TestPartialRead(t *testing.T) {
|
||||
pipe := NewStreamBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
pipe.Write(b)
|
||||
b1 := make([]byte, 1)
|
||||
n, err := pipe.Read(b1)
|
||||
assert.NoError(t, err, "partial read of 1")
|
||||
assert.Equal(t, len(b1), n, "number of bytes in partial read of 1")
|
||||
|
||||
assert.Equal(t, b[0], b1[0])
|
||||
|
||||
b2 := make([]byte, 2)
|
||||
n, err = pipe.Read(b2)
|
||||
assert.NoError(t, err, "partial read of 2")
|
||||
assert.Equal(t, len(b2), n, "number of bytes in partial read of 2")
|
||||
|
||||
assert.Equal(t, b[1:], b2)
|
||||
}
|
||||
|
||||
func TestReadAfterClose(t *testing.T) {
|
||||
pipe := NewStreamBufferedPipe()
|
||||
b := []byte{0x01, 0x02, 0x03}
|
||||
pipe.Write(b)
|
||||
b2 := make([]byte, len(b))
|
||||
pipe.Close()
|
||||
n, err := pipe.Read(b2)
|
||||
assert.NoError(t, err, "simple read")
|
||||
assert.Equal(t, len(b), n, "number of bytes read")
|
||||
|
||||
assert.Equal(t, b, b2)
|
||||
}
|
||||
|
||||
func BenchmarkBufferedPipe_RW(b *testing.B) {
|
||||
const PAYLOAD_LEN = 1000
|
||||
testData := make([]byte, PAYLOAD_LEN)
|
||||
rand.Read(testData)
|
||||
|
||||
pipe := NewStreamBufferedPipe()
|
||||
|
||||
smallBuf := make([]byte, PAYLOAD_LEN-10)
|
||||
go func() {
|
||||
for {
|
||||
pipe.Read(smallBuf)
|
||||
}
|
||||
}()
|
||||
b.SetBytes(int64(len(testData)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pipe.Write(testData)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,388 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/cbeuw/connutil"
|
||||
)
|
||||
|
||||
const payloadLen = 1000
|
||||
|
||||
var emptyKey [32]byte
|
||||
|
||||
func setupSesh(unordered bool, key [32]byte, encryptionMethod byte) *Session {
|
||||
obfuscator, _ := MakeObfuscator(encryptionMethod, key)
|
||||
|
||||
seshConfig := SessionConfig{
|
||||
Obfuscator: obfuscator,
|
||||
Valve: nil,
|
||||
Unordered: unordered,
|
||||
}
|
||||
return MakeSession(0, seshConfig)
|
||||
}
|
||||
|
||||
func BenchmarkStream_Write_Ordered(b *testing.B) {
|
||||
hole := connutil.Discard()
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
|
||||
const testDataLen = 65536
|
||||
testData := make([]byte, testDataLen)
|
||||
rand.Read(testData)
|
||||
eMethods := map[string]byte{
|
||||
"plain": EncryptionMethodPlain,
|
||||
"chacha20-poly1305": EncryptionMethodChaha20Poly1305,
|
||||
"aes-256-gcm": EncryptionMethodAES256GCM,
|
||||
"aes-128-gcm": EncryptionMethodAES128GCM,
|
||||
}
|
||||
|
||||
for name, method := range eMethods {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
sesh := setupSesh(false, sessionKey, method)
|
||||
sesh.AddConnection(hole)
|
||||
stream, _ := sesh.OpenStream()
|
||||
b.SetBytes(testDataLen)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
stream.Write(testData)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_Write(t *testing.T) {
|
||||
hole := connutil.Discard()
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
|
||||
sesh.AddConnection(hole)
|
||||
testData := make([]byte, payloadLen)
|
||||
rand.Read(testData)
|
||||
|
||||
stream, _ := sesh.OpenStream()
|
||||
_, err := stream.Write(testData)
|
||||
if err != nil {
|
||||
t.Error(
|
||||
"For", "stream write",
|
||||
"got", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_WriteSync(t *testing.T) {
|
||||
// Close calls made after write MUST have a higher seq
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
clientSesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
|
||||
serverSesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
|
||||
w, r := connutil.AsyncPipe()
|
||||
clientSesh.AddConnection(common.NewTLSConn(w))
|
||||
serverSesh.AddConnection(common.NewTLSConn(r))
|
||||
testData := make([]byte, payloadLen)
|
||||
rand.Read(testData)
|
||||
|
||||
t.Run("test single", func(t *testing.T) {
|
||||
go func() {
|
||||
stream, _ := clientSesh.OpenStream()
|
||||
stream.Write(testData)
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
recvBuf := make([]byte, payloadLen)
|
||||
serverStream, _ := serverSesh.Accept()
|
||||
_, err := io.ReadFull(serverStream, recvBuf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test multiple", func(t *testing.T) {
|
||||
const numStreams = 100
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func() {
|
||||
stream, _ := clientSesh.OpenStream()
|
||||
stream.Write(testData)
|
||||
stream.Close()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < numStreams; i++ {
|
||||
recvBuf := make([]byte, payloadLen)
|
||||
serverStream, _ := serverSesh.Accept()
|
||||
_, err := io.ReadFull(serverStream, recvBuf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStream_Close(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
testPayload := []byte{42, 42, 42}
|
||||
|
||||
dataFrame := &Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
t.Run("active closing", func(t *testing.T) {
|
||||
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
|
||||
rawConn, rawWritingEnd := connutil.AsyncPipe()
|
||||
sesh.AddConnection(common.NewTLSConn(rawConn))
|
||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||
|
||||
obfsBuf := make([]byte, 512)
|
||||
i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0)
|
||||
_, err := writingEnd.Write(obfsBuf[:i])
|
||||
if err != nil {
|
||||
t.Error("failed to write from remote end")
|
||||
}
|
||||
stream, err := sesh.Accept()
|
||||
if err != nil {
|
||||
t.Error("failed to accept stream", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
err = stream.Close()
|
||||
if err != nil {
|
||||
t.Error("failed to actively close stream", err)
|
||||
return
|
||||
}
|
||||
|
||||
sesh.streamsM.Lock()
|
||||
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
|
||||
sesh.streamsM.Unlock()
|
||||
t.Error("stream still exists")
|
||||
return
|
||||
}
|
||||
sesh.streamsM.Unlock()
|
||||
|
||||
readBuf := make([]byte, len(testPayload))
|
||||
_, err = io.ReadFull(stream, readBuf)
|
||||
if err != nil {
|
||||
t.Errorf("cannot read resiual data: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf, testPayload) {
|
||||
t.Errorf("read wrong data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("passive closing", func(t *testing.T) {
|
||||
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
|
||||
rawConn, rawWritingEnd := connutil.AsyncPipe()
|
||||
sesh.AddConnection(common.NewTLSConn(rawConn))
|
||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||
|
||||
obfsBuf := make([]byte, 512)
|
||||
i, err := sesh.obfuscate(dataFrame, obfsBuf, 0)
|
||||
if err != nil {
|
||||
t.Errorf("failed to obfuscate frame %v", err)
|
||||
}
|
||||
_, err = writingEnd.Write(obfsBuf[:i])
|
||||
if err != nil {
|
||||
t.Error("failed to write from remote end")
|
||||
}
|
||||
|
||||
stream, err := sesh.Accept()
|
||||
if err != nil {
|
||||
t.Error("failed to accept stream", err)
|
||||
return
|
||||
}
|
||||
|
||||
closingFrame := &Frame{
|
||||
1,
|
||||
dataFrame.Seq + 1,
|
||||
closingStream,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
i, err = sesh.obfuscate(closingFrame, obfsBuf, 0)
|
||||
if err != nil {
|
||||
t.Errorf("failed to obfuscate frame %v", err)
|
||||
}
|
||||
_, err = writingEnd.Write(obfsBuf[:i])
|
||||
if err != nil {
|
||||
t.Errorf("failed to write from remote end %v", err)
|
||||
}
|
||||
|
||||
closingFrameDup := &Frame{
|
||||
1,
|
||||
dataFrame.Seq + 2,
|
||||
closingStream,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0)
|
||||
if err != nil {
|
||||
t.Errorf("failed to obfuscate frame %v", err)
|
||||
}
|
||||
_, err = writingEnd.Write(obfsBuf[:i])
|
||||
if err != nil {
|
||||
t.Errorf("failed to write from remote end %v", err)
|
||||
}
|
||||
|
||||
readBuf := make([]byte, len(testPayload))
|
||||
_, err = io.ReadFull(stream, readBuf)
|
||||
if err != nil {
|
||||
t.Errorf("can't read residual data %v", err)
|
||||
}
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
sesh.streamsM.Lock()
|
||||
s, _ := sesh.streams[stream.(*Stream).id]
|
||||
sesh.streamsM.Unlock()
|
||||
return s == nil
|
||||
}, time.Second, 10*time.Millisecond, "streams still exists")
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestStream_Read(t *testing.T) {
|
||||
seshes := map[string]bool{
|
||||
"ordered": false,
|
||||
"unordered": true,
|
||||
}
|
||||
testPayload := []byte{42, 42, 42}
|
||||
const smallPayloadLen = 3
|
||||
|
||||
f := &Frame{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
testPayload,
|
||||
}
|
||||
|
||||
var streamID uint32
|
||||
|
||||
for name, unordered := range seshes {
|
||||
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
|
||||
rawConn, rawWritingEnd := connutil.AsyncPipe()
|
||||
sesh.AddConnection(common.NewTLSConn(rawConn))
|
||||
writingEnd := common.NewTLSConn(rawWritingEnd)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
buf := make([]byte, 10)
|
||||
obfsBuf := make([]byte, 512)
|
||||
t.Run("Plain read", func(t *testing.T) {
|
||||
f.StreamID = streamID
|
||||
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||
streamID++
|
||||
writingEnd.Write(obfsBuf[:i])
|
||||
stream, err := sesh.Accept()
|
||||
if err != nil {
|
||||
t.Error("failed to accept stream", err)
|
||||
return
|
||||
}
|
||||
i, err = stream.Read(buf)
|
||||
if err != nil {
|
||||
t.Error("failed to read", err)
|
||||
return
|
||||
}
|
||||
if i != smallPayloadLen {
|
||||
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(buf[:i], testPayload) {
|
||||
t.Error("expected", testPayload,
|
||||
"got", buf[:i])
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("Nil buf", func(t *testing.T) {
|
||||
f.StreamID = streamID
|
||||
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||
streamID++
|
||||
writingEnd.Write(obfsBuf[:i])
|
||||
stream, _ := sesh.Accept()
|
||||
i, err := stream.Read(nil)
|
||||
if i != 0 || err != nil {
|
||||
t.Error("expecting", 0, nil,
|
||||
"got", i, err)
|
||||
}
|
||||
})
|
||||
t.Run("Read after stream close", func(t *testing.T) {
|
||||
f.StreamID = streamID
|
||||
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||
streamID++
|
||||
writingEnd.Write(obfsBuf[:i])
|
||||
stream, _ := sesh.Accept()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
stream.Close()
|
||||
|
||||
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
|
||||
if err != nil {
|
||||
t.Errorf("cannot read residual data: %v", err)
|
||||
}
|
||||
if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
|
||||
t.Error("expected", testPayload,
|
||||
"got", buf[:smallPayloadLen])
|
||||
}
|
||||
_, err = stream.Read(buf)
|
||||
if err == nil {
|
||||
t.Error("expecting error", ErrBrokenStream,
|
||||
"got nil error")
|
||||
}
|
||||
})
|
||||
t.Run("Read after session close", func(t *testing.T) {
|
||||
f.StreamID = streamID
|
||||
i, _ := sesh.obfuscate(f, obfsBuf, 0)
|
||||
streamID++
|
||||
writingEnd.Write(obfsBuf[:i])
|
||||
stream, _ := sesh.Accept()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
sesh.Close()
|
||||
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
|
||||
if err != nil {
|
||||
t.Errorf("cannot read resiual data: %v", err)
|
||||
}
|
||||
if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
|
||||
t.Error("expected", testPayload,
|
||||
"got", buf[:smallPayloadLen])
|
||||
}
|
||||
_, err = stream.Read(buf)
|
||||
if err == nil {
|
||||
t.Error("expecting error", ErrBrokenStream,
|
||||
"got nil error")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_SetReadFromTimeout(t *testing.T) {
|
||||
seshes := map[string]*Session{
|
||||
"ordered": setupSesh(false, emptyKey, EncryptionMethodPlain),
|
||||
"unordered": setupSesh(true, emptyKey, EncryptionMethodPlain),
|
||||
}
|
||||
for name, sesh := range seshes {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
stream, _ := sesh.OpenStream()
|
||||
stream.SetReadFromTimeout(100 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
stream.ReadFrom(connutil.Discard())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("didn't timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -2,165 +2,162 @@ package multiplex
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"math/rand/v2"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type switchboardStrategy int
|
||||
|
||||
const (
|
||||
fixedConnMapping switchboardStrategy = iota
|
||||
uniformSpread
|
||||
)
|
||||
|
||||
// switchboard represents the connection pool. It is responsible for managing
|
||||
// transport-layer connections between client and server.
|
||||
// It has several purposes: constantly receiving incoming data from all connections
|
||||
// and pass them to Session.recvDataFromRemote(); accepting data through
|
||||
// switchboard.send(), in which it selects a connection according to its
|
||||
// switchboardStrategy and send the data off using that; and counting, as well as
|
||||
// rate limiting, data received and sent through its Valve.
|
||||
// switchboard is responsible for keeping the reference of TLS connections between client and server
|
||||
type switchboard struct {
|
||||
session *Session
|
||||
|
||||
valve Valve
|
||||
strategy switchboardStrategy
|
||||
*Valve
|
||||
|
||||
conns sync.Map
|
||||
connsCount uint32
|
||||
randPool sync.Pool
|
||||
|
||||
broken uint32
|
||||
// optimum is the connEnclave with the smallest sendQueue
|
||||
optimum atomic.Value // *connEnclave
|
||||
cesM sync.RWMutex
|
||||
ces []*connEnclave
|
||||
}
|
||||
|
||||
func makeSwitchboard(sesh *Session) *switchboard {
|
||||
func (sb *switchboard) getOptimum() *connEnclave {
|
||||
if i := sb.optimum.Load(); i == nil {
|
||||
return nil
|
||||
} else {
|
||||
return i.(*connEnclave)
|
||||
}
|
||||
}
|
||||
|
||||
// Some data comes from a Stream to be sent through one of the many
|
||||
// remoteConn, but which remoteConn should we use to send the data?
|
||||
//
|
||||
// In this case, we pick the remoteConn that has about the smallest sendQueue.
|
||||
type connEnclave struct {
|
||||
remoteConn net.Conn
|
||||
sendQueue uint32
|
||||
}
|
||||
|
||||
func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
|
||||
// rates are uint64 because in the usermanager we want the bandwidth to be atomically
|
||||
// operated (so that the bandwidth can change on the fly).
|
||||
sb := &switchboard{
|
||||
session: sesh,
|
||||
strategy: uniformSpread,
|
||||
valve: sesh.Valve,
|
||||
randPool: sync.Pool{New: func() interface{} {
|
||||
var state [32]byte
|
||||
common.CryptoRandRead(state[:])
|
||||
return rand.New(rand.NewChaCha8(state))
|
||||
}},
|
||||
session: sesh,
|
||||
Valve: valve,
|
||||
ces: []*connEnclave{},
|
||||
}
|
||||
return sb
|
||||
}
|
||||
|
||||
var errBrokenSwitchboard = errors.New("the switchboard is broken")
|
||||
var errNilOptimum error = errors.New("The optimal connection is nil")
|
||||
|
||||
func (sb *switchboard) addConn(conn net.Conn) {
|
||||
connId := atomic.AddUint32(&sb.connsCount, 1) - 1
|
||||
sb.conns.Store(connId, conn)
|
||||
go sb.deplex(conn)
|
||||
}
|
||||
var ErrNoRxCredit error = errors.New("No Rx credit is left")
|
||||
var ErrNoTxCredit error = errors.New("No Tx credit is left")
|
||||
|
||||
// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable
|
||||
func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) {
|
||||
sb.valve.txWait(len(data))
|
||||
if atomic.LoadUint32(&sb.broken) == 1 {
|
||||
return 0, errBrokenSwitchboard
|
||||
func (sb *switchboard) send(data []byte) (int, error) {
|
||||
ce := sb.getOptimum()
|
||||
if ce == nil {
|
||||
return 0, errNilOptimum
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
switch sb.strategy {
|
||||
case uniformSpread:
|
||||
conn, err = sb.pickRandConn()
|
||||
if err != nil {
|
||||
return 0, errBrokenSwitchboard
|
||||
}
|
||||
n, err = conn.Write(data)
|
||||
if err != nil {
|
||||
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
|
||||
sb.session.passiveClose()
|
||||
return n, err
|
||||
}
|
||||
case fixedConnMapping:
|
||||
// FIXME: this strategy has a tendency to cause a TLS conn socket buffer to fill up,
|
||||
// which is a problem when multiple streams are mapped to the same conn, resulting
|
||||
// in all such streams being blocked.
|
||||
conn = *assignedConn
|
||||
if conn == nil {
|
||||
conn, err = sb.pickRandConn()
|
||||
if err != nil {
|
||||
sb.session.SetTerminalMsg("failed to pick a connection " + err.Error())
|
||||
sb.session.passiveClose()
|
||||
return 0, err
|
||||
}
|
||||
*assignedConn = conn
|
||||
}
|
||||
n, err = conn.Write(data)
|
||||
if err != nil {
|
||||
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
|
||||
sb.session.passiveClose()
|
||||
return n, err
|
||||
}
|
||||
default:
|
||||
return 0, errors.New("unsupported traffic distribution strategy")
|
||||
atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
|
||||
go sb.updateOptimum()
|
||||
n, err := ce.remoteConn.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
sb.valve.AddTx(int64(n))
|
||||
sb.txWait(n)
|
||||
if sb.AddTxCredit(-int64(n)) < 0 {
|
||||
log.Println(ErrNoTxCredit)
|
||||
go sb.session.Close()
|
||||
return n, ErrNoTxCredit
|
||||
}
|
||||
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
|
||||
go sb.updateOptimum()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// returns a random conn. This function can be called concurrently.
|
||||
func (sb *switchboard) pickRandConn() (net.Conn, error) {
|
||||
if atomic.LoadUint32(&sb.broken) == 1 {
|
||||
return nil, errBrokenSwitchboard
|
||||
func (sb *switchboard) updateOptimum() {
|
||||
currentOpti := sb.getOptimum()
|
||||
currentOptiQ := atomic.LoadUint32(¤tOpti.sendQueue)
|
||||
sb.cesM.RLock()
|
||||
for _, ce := range sb.ces {
|
||||
ceQ := atomic.LoadUint32(&ce.sendQueue)
|
||||
if ceQ < currentOptiQ {
|
||||
currentOpti = ce
|
||||
currentOptiQ = ceQ
|
||||
}
|
||||
}
|
||||
sb.cesM.RUnlock()
|
||||
sb.optimum.Store(currentOpti)
|
||||
}
|
||||
|
||||
connsCount := atomic.LoadUint32(&sb.connsCount)
|
||||
if connsCount == 0 {
|
||||
return nil, errBrokenSwitchboard
|
||||
func (sb *switchboard) addConn(conn net.Conn) {
|
||||
var sendQueue uint32
|
||||
newCe := &connEnclave{
|
||||
remoteConn: conn,
|
||||
sendQueue: sendQueue,
|
||||
}
|
||||
sb.cesM.Lock()
|
||||
sb.ces = append(sb.ces, newCe)
|
||||
sb.cesM.Unlock()
|
||||
sb.optimum.Store(newCe)
|
||||
go sb.deplex(newCe)
|
||||
}
|
||||
|
||||
randReader := sb.randPool.Get().(*rand.Rand)
|
||||
connId := randReader.Uint32N(connsCount)
|
||||
sb.randPool.Put(randReader)
|
||||
|
||||
ret, ok := sb.conns.Load(connId)
|
||||
if !ok {
|
||||
log.Errorf("failed to get conn %d", connId)
|
||||
return nil, errBrokenSwitchboard
|
||||
func (sb *switchboard) removeConn(closing *connEnclave) {
|
||||
sb.cesM.Lock()
|
||||
for i, ce := range sb.ces {
|
||||
if closing == ce {
|
||||
sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
return ret.(net.Conn), nil
|
||||
if len(sb.ces) == 0 {
|
||||
sb.cesM.Unlock()
|
||||
sb.session.Close()
|
||||
return
|
||||
}
|
||||
sb.cesM.Unlock()
|
||||
}
|
||||
|
||||
// actively triggered by session.Close()
|
||||
func (sb *switchboard) closeAll() {
|
||||
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
||||
return
|
||||
sb.cesM.RLock()
|
||||
for _, ce := range sb.ces {
|
||||
ce.remoteConn.Close()
|
||||
}
|
||||
atomic.StoreUint32(&sb.connsCount, 0)
|
||||
sb.conns.Range(func(_, conn interface{}) bool {
|
||||
conn.(net.Conn).Close()
|
||||
sb.conns.Delete(conn)
|
||||
return true
|
||||
})
|
||||
sb.cesM.RUnlock()
|
||||
}
|
||||
|
||||
// deplex function costantly reads from a TCP connection
|
||||
func (sb *switchboard) deplex(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, sb.session.connReceiveBufferSize)
|
||||
// deplex function costantly reads from a TCP connection, call deobfs and distribute it
|
||||
// to the corresponding stream
|
||||
func (sb *switchboard) deplex(ce *connEnclave) {
|
||||
buf := make([]byte, 20480)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
sb.valve.rxWait(n)
|
||||
sb.valve.AddRx(int64(n))
|
||||
n, err := sb.session.obfsedRead(ce.remoteConn, buf)
|
||||
if err != nil {
|
||||
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
|
||||
sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
|
||||
sb.session.passiveClose()
|
||||
//log.Println(err)
|
||||
go ce.remoteConn.Close()
|
||||
sb.removeConn(ce)
|
||||
return
|
||||
}
|
||||
|
||||
err = sb.session.recvDataFromRemote(buf[:n])
|
||||
sb.rxWait(n)
|
||||
if sb.AddRxCredit(-int64(n)) < 0 {
|
||||
log.Println(ErrNoRxCredit)
|
||||
sb.session.Close()
|
||||
return
|
||||
}
|
||||
frame, err := sb.session.deobfs(buf[:n])
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
|
||||
stream := sb.session.getStream(frame.StreamID, frame.Closing == 1)
|
||||
// if the frame is telling us to close a closed stream
|
||||
// (this happens when ss-server and ss-local closes the stream
|
||||
// simutaneously), we don't do anything
|
||||
if stream != nil {
|
||||
stream.writeNewFrame(frame)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,187 +0,0 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/connutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSwitchboard_Send(t *testing.T) {
|
||||
doTest := func(seshConfig SessionConfig) {
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
hole0 := connutil.Discard()
|
||||
sesh.sb.addConn(hole0)
|
||||
conn, err := sesh.sb.pickRandConn()
|
||||
if err != nil {
|
||||
t.Error("failed to get a random conn", err)
|
||||
return
|
||||
}
|
||||
data := make([]byte, 1000)
|
||||
rand.Read(data)
|
||||
_, err = sesh.sb.send(data, &conn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
hole1 := connutil.Discard()
|
||||
sesh.sb.addConn(hole1)
|
||||
conn, err = sesh.sb.pickRandConn()
|
||||
if err != nil {
|
||||
t.Error("failed to get a random conn", err)
|
||||
return
|
||||
}
|
||||
_, err = sesh.sb.send(data, &conn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err = sesh.sb.pickRandConn()
|
||||
if err != nil {
|
||||
t.Error("failed to get a random conn", err)
|
||||
return
|
||||
}
|
||||
_, err = sesh.sb.send(data, &conn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Ordered", func(t *testing.T) {
|
||||
seshConfig := SessionConfig{
|
||||
Unordered: false,
|
||||
}
|
||||
doTest(seshConfig)
|
||||
})
|
||||
t.Run("Unordered", func(t *testing.T) {
|
||||
seshConfig := SessionConfig{
|
||||
Unordered: true,
|
||||
}
|
||||
doTest(seshConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkSwitchboard_Send(b *testing.B) {
|
||||
hole := connutil.Discard()
|
||||
seshConfig := SessionConfig{}
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
sesh.sb.addConn(hole)
|
||||
conn, err := sesh.sb.pickRandConn()
|
||||
if err != nil {
|
||||
b.Error("failed to get a random conn", err)
|
||||
return
|
||||
}
|
||||
data := make([]byte, 1000)
|
||||
rand.Read(data)
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sesh.sb.send(data, &conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchboard_TxCredit(t *testing.T) {
|
||||
seshConfig := SessionConfig{
|
||||
Valve: MakeValve(1<<20, 1<<20),
|
||||
}
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
hole := connutil.Discard()
|
||||
sesh.sb.addConn(hole)
|
||||
conn, err := sesh.sb.pickRandConn()
|
||||
if err != nil {
|
||||
t.Error("failed to get a random conn", err)
|
||||
return
|
||||
}
|
||||
data := make([]byte, 1000)
|
||||
rand.Read(data)
|
||||
|
||||
t.Run("fixed conn mapping", func(t *testing.T) {
|
||||
*sesh.sb.valve.(*LimitedValve).tx = 0
|
||||
sesh.sb.strategy = fixedConnMapping
|
||||
n, err := sesh.sb.send(data[:10], &conn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if n != 10 {
|
||||
t.Errorf("wanted to send %v, got %v", 10, n)
|
||||
return
|
||||
}
|
||||
if *sesh.sb.valve.(*LimitedValve).tx != 10 {
|
||||
t.Error("tx credit didn't increase by 10")
|
||||
}
|
||||
})
|
||||
t.Run("uniform spread", func(t *testing.T) {
|
||||
*sesh.sb.valve.(*LimitedValve).tx = 0
|
||||
sesh.sb.strategy = uniformSpread
|
||||
n, err := sesh.sb.send(data[:10], &conn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if n != 10 {
|
||||
t.Errorf("wanted to send %v, got %v", 10, n)
|
||||
return
|
||||
}
|
||||
if *sesh.sb.valve.(*LimitedValve).tx != 10 {
|
||||
t.Error("tx credit didn't increase by 10")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSwitchboard_CloseOnOneDisconn(t *testing.T) {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
|
||||
|
||||
conn0client, conn0server := connutil.AsyncPipe()
|
||||
sesh.AddConnection(conn0client)
|
||||
|
||||
conn1client, _ := connutil.AsyncPipe()
|
||||
sesh.AddConnection(conn1client)
|
||||
|
||||
conn0server.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return sesh.IsClosed()
|
||||
}, time.Second, 10*time.Millisecond, "session not closed after one conn is disconnected")
|
||||
|
||||
if _, err := conn1client.Write([]byte{0x00}); err == nil {
|
||||
t.Error("the other conn is still connected")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchboard_ConnsCount(t *testing.T) {
|
||||
seshConfig := SessionConfig{
|
||||
Valve: MakeValve(1<<20, 1<<20),
|
||||
}
|
||||
sesh := MakeSession(0, seshConfig)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1000; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
sesh.AddConnection(connutil.Discard())
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
|
||||
t.Error("connsCount incorrect")
|
||||
}
|
||||
|
||||
sesh.sb.closeAll()
|
||||
|
||||
assert.Eventuallyf(t, func() bool {
|
||||
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
|
||||
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
|
||||
}
|
||||
|
|
@ -1,101 +1,166 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
const appDataMaxLength = 16401
|
||||
// ClientHello contains every field in a ClientHello message
|
||||
type ClientHello struct {
|
||||
handshakeType byte
|
||||
length int
|
||||
clientVersion []byte
|
||||
random []byte
|
||||
sessionIdLen int
|
||||
sessionId []byte
|
||||
cipherSuitesLen int
|
||||
cipherSuites []byte
|
||||
compressionMethodsLen int
|
||||
compressionMethods []byte
|
||||
extensionsLen int
|
||||
extensions map[[2]byte][]byte
|
||||
}
|
||||
|
||||
type TLS struct{}
|
||||
var u16 = binary.BigEndian.Uint16
|
||||
var u32 = binary.BigEndian.Uint32
|
||||
|
||||
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
|
||||
|
||||
func (TLS) String() string { return "TLS" }
|
||||
|
||||
func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) {
|
||||
ch, err := parseClientHello(clientHello)
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
err = ErrBadClientHello
|
||||
return
|
||||
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("Malformed Extensions")
|
||||
}
|
||||
}()
|
||||
pointer := 0
|
||||
totalLen := len(input)
|
||||
ret = make(map[[2]byte][]byte)
|
||||
for pointer < totalLen {
|
||||
var typ [2]byte
|
||||
copy(typ[:], input[pointer:pointer+2])
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
data := input[pointer : pointer+length]
|
||||
pointer += length
|
||||
ret[typ] = data
|
||||
}
|
||||
return ret, err
|
||||
}
|
||||
|
||||
fragments, err = TLS{}.unmarshalClientHello(ch, privateKey)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal ClientHello into authFragments: %v", err)
|
||||
return
|
||||
// AddRecordLayer adds record layer to data
|
||||
func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
|
||||
length := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(length, uint16(len(input)))
|
||||
ret := make([]byte, 5+len(input))
|
||||
copy(ret[0:1], typ)
|
||||
copy(ret[1:3], ver)
|
||||
copy(ret[3:5], length)
|
||||
copy(ret[5:], input)
|
||||
return ret
|
||||
}
|
||||
|
||||
// PeelRecordLayer peels off the record layer
|
||||
func PeelRecordLayer(data []byte) []byte {
|
||||
ret := data[5:]
|
||||
return ret
|
||||
}
|
||||
|
||||
// ParseClientHello parses everything on top of the TLS layer
|
||||
// (including the record layer) into ClientHello type
|
||||
func ParseClientHello(data []byte) (ret *ClientHello, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("Malformed ClientHello")
|
||||
}
|
||||
}()
|
||||
data = PeelRecordLayer(data)
|
||||
pointer := 0
|
||||
// Handshake Type
|
||||
handshakeType := data[pointer]
|
||||
if handshakeType != 0x01 {
|
||||
return ret, errors.New("Not a ClientHello")
|
||||
}
|
||||
pointer += 1
|
||||
// Length
|
||||
length := int(u32(append([]byte{0x00}, data[pointer:pointer+3]...)))
|
||||
pointer += 3
|
||||
if length != len(data[pointer:]) {
|
||||
return ret, errors.New("Hello length doesn't match")
|
||||
}
|
||||
// Client Version
|
||||
clientVersion := data[pointer : pointer+2]
|
||||
pointer += 2
|
||||
// Random
|
||||
random := data[pointer : pointer+32]
|
||||
pointer += 32
|
||||
// Session ID
|
||||
sessionIdLen := int(data[pointer])
|
||||
pointer += 1
|
||||
sessionId := data[pointer : pointer+sessionIdLen]
|
||||
pointer += sessionIdLen
|
||||
// Cipher Suites
|
||||
cipherSuitesLen := int(u16(data[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
cipherSuites := data[pointer : pointer+cipherSuitesLen]
|
||||
pointer += cipherSuitesLen
|
||||
// Compression Methods
|
||||
compressionMethodsLen := int(data[pointer])
|
||||
pointer += 1
|
||||
compressionMethods := data[pointer : pointer+compressionMethodsLen]
|
||||
pointer += compressionMethodsLen
|
||||
// Extensions
|
||||
extensionsLen := int(u16(data[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
extensions, err := parseExtensions(data[pointer:])
|
||||
ret = &ClientHello{
|
||||
handshakeType,
|
||||
length,
|
||||
clientVersion,
|
||||
random,
|
||||
sessionIdLen,
|
||||
sessionId,
|
||||
cipherSuitesLen,
|
||||
cipherSuites,
|
||||
compressionMethodsLen,
|
||||
compressionMethods,
|
||||
extensionsLen,
|
||||
extensions,
|
||||
}
|
||||
|
||||
respond = TLS{}.makeResponder(ch.sessionId, fragments.sharedSecret)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret [32]byte) Responder {
|
||||
respond := func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error) {
|
||||
// the cert length needs to be the same for all handshakes belonging to the same session
|
||||
// we can use sessionKey as a seed here to ensure consistency
|
||||
possibleCertLengths := []int{42, 27, 68, 59, 36, 44, 46}
|
||||
cert := make([]byte, possibleCertLengths[common.RandInt(len(possibleCertLengths))])
|
||||
common.RandRead(randSource, cert)
|
||||
|
||||
var nonce [12]byte
|
||||
common.RandRead(randSource, nonce[:])
|
||||
encryptedSessionKey, err := common.AESGCMEncrypt(nonce[:], sharedSecret[:], sessionKey[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var encryptedSessionKeyArr [48]byte
|
||||
copy(encryptedSessionKeyArr[:], encryptedSessionKey)
|
||||
|
||||
reply := composeReply(clientHelloSessionId, nonce, encryptedSessionKeyArr, cert)
|
||||
_, err = originalConn.Write(reply)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to write TLS reply: %v", err)
|
||||
originalConn.Close()
|
||||
return
|
||||
}
|
||||
preparedConn = common.NewTLSConn(originalConn)
|
||||
return
|
||||
func composeServerHello(ch *ClientHello) []byte {
|
||||
var serverHello [10][]byte
|
||||
serverHello[0] = []byte{0x02} // handshake type
|
||||
serverHello[1] = []byte{0x00, 0x00, 0x4d} // length 77
|
||||
serverHello[2] = []byte{0x03, 0x03} // server version
|
||||
serverHello[3] = util.PsudoRandBytes(32, time.Now().UnixNano()) // random
|
||||
serverHello[4] = []byte{0x20} // session id length 32
|
||||
serverHello[5] = ch.sessionId // session id
|
||||
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
serverHello[7] = []byte{0x00} // compression method null
|
||||
serverHello[8] = []byte{0x00, 0x05} // extensions length 5
|
||||
serverHello[9] = []byte{0xff, 0x01, 0x00, 0x01, 0x00} // extensions renegotiation_info
|
||||
ret := []byte{}
|
||||
for i := 0; i < 10; i++ {
|
||||
ret = append(ret, serverHello[i]...)
|
||||
}
|
||||
return respond
|
||||
return ret
|
||||
}
|
||||
|
||||
func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fragments authFragments, err error) {
|
||||
copy(fragments.randPubKey[:], ch.random)
|
||||
ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:])
|
||||
if !ok {
|
||||
err = ErrInvalidPubKey
|
||||
return
|
||||
}
|
||||
|
||||
var sharedSecret []byte
|
||||
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
copy(fragments.sharedSecret[:], sharedSecret)
|
||||
var keyShare []byte
|
||||
keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctxTag := append(ch.sessionId, keyShare...)
|
||||
if len(ctxTag) != 64 {
|
||||
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(ctxTag))
|
||||
return
|
||||
}
|
||||
copy(fragments.ciphertextWithTag[:], ctxTag)
|
||||
return
|
||||
// ComposeReply composes the ServerHello, ChangeCipherSpec and Finished messages
|
||||
// together with their respective record layers into one byte slice. The content
|
||||
// of these messages are random and useless for this plugin
|
||||
func ComposeReply(ch *ClientHello) []byte {
|
||||
TLS12 := []byte{0x03, 0x03}
|
||||
shBytes := AddRecordLayer(composeServerHello(ch), []byte{0x16}, TLS12)
|
||||
ccsBytes := AddRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
|
||||
finished := make([]byte, 64)
|
||||
finished = util.PsudoRandBytes(40, time.Now().UnixNano())
|
||||
fBytes := AddRecordLayer(finished, []byte{0x16}, TLS12)
|
||||
ret := append(shBytes, ccsBytes...)
|
||||
ret = append(ret, fBytes...)
|
||||
return ret
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,202 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
)
|
||||
|
||||
// ClientHello contains every field in a ClientHello message
|
||||
type ClientHello struct {
|
||||
handshakeType byte
|
||||
length int
|
||||
clientVersion []byte
|
||||
random []byte
|
||||
sessionIdLen int
|
||||
sessionId []byte
|
||||
cipherSuitesLen int
|
||||
cipherSuites []byte
|
||||
compressionMethodsLen int
|
||||
compressionMethods []byte
|
||||
extensionsLen int
|
||||
extensions map[[2]byte][]byte
|
||||
}
|
||||
|
||||
var u16 = binary.BigEndian.Uint16
|
||||
var u32 = binary.BigEndian.Uint32
|
||||
|
||||
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("Malformed Extensions")
|
||||
}
|
||||
}()
|
||||
pointer := 0
|
||||
totalLen := len(input)
|
||||
ret = make(map[[2]byte][]byte)
|
||||
for pointer < totalLen {
|
||||
var typ [2]byte
|
||||
copy(typ[:], input[pointer:pointer+2])
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
data := input[pointer : pointer+length]
|
||||
pointer += length
|
||||
ret[typ] = data
|
||||
}
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func parseKeyShare(input []byte) (ret []byte, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("malformed key_share")
|
||||
}
|
||||
}()
|
||||
totalLen := int(u16(input[0:2]))
|
||||
// 2 bytes "client key share length"
|
||||
pointer := 2
|
||||
for pointer < totalLen {
|
||||
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
|
||||
// skip "key exchange length"
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
if length != 32 {
|
||||
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
|
||||
}
|
||||
return input[pointer : pointer+length], nil
|
||||
}
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
_ = input[pointer : pointer+length]
|
||||
pointer += length
|
||||
}
|
||||
return nil, errors.New("x25519 does not exist")
|
||||
}
|
||||
|
||||
// addRecordLayer adds record layer to data
|
||||
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
|
||||
length := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(length, uint16(len(input)))
|
||||
ret := make([]byte, 5+len(input))
|
||||
copy(ret[0:1], typ)
|
||||
copy(ret[1:3], ver)
|
||||
copy(ret[3:5], length)
|
||||
copy(ret[5:], input)
|
||||
return ret
|
||||
}
|
||||
|
||||
// parseClientHello parses everything on top of the TLS layer
|
||||
// (including the record layer) into ClientHello type
|
||||
func parseClientHello(data []byte) (ret *ClientHello, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("Malformed ClientHello")
|
||||
}
|
||||
}()
|
||||
|
||||
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
|
||||
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
|
||||
}
|
||||
|
||||
peeled := make([]byte, len(data)-5)
|
||||
copy(peeled, data[5:])
|
||||
pointer := 0
|
||||
// Handshake Type
|
||||
handshakeType := peeled[pointer]
|
||||
if handshakeType != 0x01 {
|
||||
return ret, errors.New("Not a ClientHello")
|
||||
}
|
||||
pointer += 1
|
||||
// Length
|
||||
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
|
||||
pointer += 3
|
||||
if length != len(peeled[pointer:]) {
|
||||
return ret, errors.New("Hello length doesn't match")
|
||||
}
|
||||
// Client Version
|
||||
clientVersion := peeled[pointer : pointer+2]
|
||||
pointer += 2
|
||||
// Random
|
||||
random := peeled[pointer : pointer+32]
|
||||
pointer += 32
|
||||
// Session ID
|
||||
sessionIdLen := int(peeled[pointer])
|
||||
pointer += 1
|
||||
sessionId := peeled[pointer : pointer+sessionIdLen]
|
||||
pointer += sessionIdLen
|
||||
// Cipher Suites
|
||||
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
|
||||
pointer += cipherSuitesLen
|
||||
// Compression Methods
|
||||
compressionMethodsLen := int(peeled[pointer])
|
||||
pointer += 1
|
||||
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
|
||||
pointer += compressionMethodsLen
|
||||
// Extensions
|
||||
extensionsLen := int(u16(peeled[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
extensions, err := parseExtensions(peeled[pointer:])
|
||||
ret = &ClientHello{
|
||||
handshakeType,
|
||||
length,
|
||||
clientVersion,
|
||||
random,
|
||||
sessionIdLen,
|
||||
sessionId,
|
||||
cipherSuitesLen,
|
||||
cipherSuites,
|
||||
compressionMethodsLen,
|
||||
compressionMethods,
|
||||
extensionsLen,
|
||||
extensions,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func composeServerHello(sessionId []byte, nonce [12]byte, encryptedSessionKeyWithTag [48]byte) []byte {
|
||||
var serverHello [11][]byte
|
||||
serverHello[0] = []byte{0x02} // handshake type
|
||||
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 118
|
||||
serverHello[2] = []byte{0x03, 0x03} // server version
|
||||
serverHello[3] = append(nonce[0:12], encryptedSessionKeyWithTag[0:20]...) // random 32 bytes
|
||||
serverHello[4] = []byte{0x20} // session id length 32
|
||||
serverHello[5] = sessionId // session id
|
||||
serverHello[6] = []byte{0x13, 0x02} // cipher suite TLS_AES_256_GCM_SHA384
|
||||
serverHello[7] = []byte{0x00} // compression method null
|
||||
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
|
||||
|
||||
keyShare := []byte{0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20}
|
||||
keyExchange := make([]byte, 32)
|
||||
copy(keyExchange, encryptedSessionKeyWithTag[20:48])
|
||||
common.CryptoRandRead(keyExchange[28:32])
|
||||
serverHello[9] = append(keyShare, keyExchange...)
|
||||
|
||||
serverHello[10] = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} // supported versions
|
||||
var ret []byte
|
||||
for _, s := range serverHello {
|
||||
ret = append(ret, s...)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
|
||||
// together with their respective record layers into one byte slice.
|
||||
func composeReply(clientHelloSessionId []byte, nonce [12]byte, encryptedSessionKeyWithTag [48]byte, cert []byte) []byte {
|
||||
TLS12 := []byte{0x03, 0x03}
|
||||
sh := composeServerHello(clientHelloSessionId, nonce, encryptedSessionKeyWithTag)
|
||||
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
|
||||
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
|
||||
|
||||
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
|
||||
ret := append(shBytes, ccsBytes...)
|
||||
ret = append(ret, encryptedCertBytes...)
|
||||
return ret
|
||||
}
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseClientHello(t *testing.T) {
|
||||
t.Run("good Cloak ClientHello", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fbf21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
ch, err := parseClientHello(chBytes)
|
||||
if err != nil {
|
||||
t.Errorf("Expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(ch.clientVersion, []byte{0x03, 0x03}) {
|
||||
t.Errorf("expecting client version 0x0303, got %v", ch.clientVersion)
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("Malformed ClientHello", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fb2f21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
_, err := parseClientHello(chBytes)
|
||||
if err == nil {
|
||||
t.Error("expecting Malformed ClientHello, got no error")
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("not Handshake", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("ff03010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fbf21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
_, err := parseClientHello(chBytes)
|
||||
if err == nil {
|
||||
t.Error("not a tls handshake, got no error")
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("wrong TLS record layer version", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("16ff010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fbf21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
_, err := parseClientHello(chBytes)
|
||||
if err == nil {
|
||||
t.Error("wrong version, got no error")
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("TLS 1.2", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("16030300bd010000b903035d5741ed86719917a932db1dc59a22c7166bf90f5bd693564341d091ffbac5db00002ac02cc02bc030c02f009f009ec024c023c028c027c00ac009c014c013009d009c003d003c0035002f000a0100006600000022002000001d6e61762e736d61727473637265656e2e6d6963726f736f66742e636f6d000500050100000000000a00080006001d00170018000b00020100000d001400120401050102010403050302030202060106030023000000170000ff01000100")
|
||||
_, err := parseClientHello(chBytes)
|
||||
if err == nil {
|
||||
t.Error("wrong version, got no error")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
)
|
||||
|
||||
type ActiveUser struct {
|
||||
panel *userPanel
|
||||
|
||||
arrUID [16]byte
|
||||
|
||||
valve mux.Valve
|
||||
|
||||
bypass bool
|
||||
|
||||
sessionsM sync.RWMutex
|
||||
sessions map[uint32]*mux.Session
|
||||
}
|
||||
|
||||
// CloseSession closes a session and removes its reference from the user
|
||||
func (u *ActiveUser) CloseSession(sessionID uint32, reason string) {
|
||||
u.sessionsM.Lock()
|
||||
sesh, existing := u.sessions[sessionID]
|
||||
if existing {
|
||||
delete(u.sessions, sessionID)
|
||||
sesh.SetTerminalMsg(reason)
|
||||
sesh.Close()
|
||||
}
|
||||
remaining := len(u.sessions)
|
||||
u.sessionsM.Unlock()
|
||||
if remaining == 0 {
|
||||
u.panel.TerminateActiveUser(u, "no session left")
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession returns the reference to an existing session, or if one such session doesn't exist, it queries
|
||||
// the UserManager for the authorisation for a new session. If a new session is allowed, it creates this new session
|
||||
// and returns its reference
|
||||
func (u *ActiveUser) GetSession(sessionID uint32, config mux.SessionConfig) (sesh *mux.Session, existing bool, err error) {
|
||||
u.sessionsM.Lock()
|
||||
defer u.sessionsM.Unlock()
|
||||
if sesh = u.sessions[sessionID]; sesh != nil {
|
||||
return sesh, true, nil
|
||||
} else {
|
||||
if !u.bypass {
|
||||
ainfo := usermanager.AuthorisationInfo{NumExistingSessions: len(u.sessions)}
|
||||
err := u.panel.Manager.AuthoriseNewSession(u.arrUID[:], ainfo)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
config.Valve = u.valve
|
||||
sesh = mux.MakeSession(sessionID, config)
|
||||
u.sessions[sessionID] = sesh
|
||||
return sesh, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeAllSessions closes all sessions of this active user
|
||||
func (u *ActiveUser) closeAllSessions(reason string) {
|
||||
u.sessionsM.Lock()
|
||||
for sessionID, sesh := range u.sessions {
|
||||
sesh.SetTerminalMsg(reason)
|
||||
sesh.Close()
|
||||
delete(u.sessions, sessionID)
|
||||
}
|
||||
u.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
// NumSession returns the number of active sessions
|
||||
func (u *ActiveUser) NumSession() int {
|
||||
u.sessionsM.RLock()
|
||||
defer u.sessionsM.RUnlock()
|
||||
return len(u.sessions)
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
)
|
||||
|
||||
func getSeshConfig(unordered bool) mux.SessionConfig {
|
||||
var sessionKey [32]byte
|
||||
rand.Read(sessionKey[:])
|
||||
obfuscator, _ := mux.MakeObfuscator(0x00, sessionKey)
|
||||
|
||||
seshConfig := mux.SessionConfig{
|
||||
Obfuscator: obfuscator,
|
||||
Valve: nil,
|
||||
Unordered: unordered,
|
||||
}
|
||||
return seshConfig
|
||||
}
|
||||
|
||||
func TestActiveUser_Bypass(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
|
||||
manager, err := usermanager.MakeLocalManager(tmpDB.Name(), common.RealWorldState)
|
||||
if err != nil {
|
||||
t.Fatal("failed to make local manager", err)
|
||||
}
|
||||
panel := MakeUserPanel(manager)
|
||||
UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==")
|
||||
user, _ := panel.GetBypassUser(UID)
|
||||
var sesh0 *mux.Session
|
||||
var existing bool
|
||||
var sesh1 *mux.Session
|
||||
|
||||
// get first session
|
||||
sesh0, existing, err = user.GetSession(0, getSeshConfig(false))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if existing {
|
||||
t.Fatal("get first session: first session returned as existing")
|
||||
}
|
||||
if sesh0 == nil {
|
||||
t.Fatal("get first session: no session returned")
|
||||
}
|
||||
|
||||
// get first session again
|
||||
seshx, existing, err := user.GetSession(0, mux.SessionConfig{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !existing {
|
||||
t.Fatal("get first session again: first session get again returned as not existing")
|
||||
}
|
||||
if seshx == nil {
|
||||
t.Fatal("get first session again: no session returned")
|
||||
}
|
||||
if seshx != sesh0 {
|
||||
t.Fatal("returned a different instance")
|
||||
}
|
||||
|
||||
// get second session
|
||||
sesh1, existing, err = user.GetSession(1, getSeshConfig(false))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if existing {
|
||||
t.Fatal("get second session: second session returned as existing")
|
||||
}
|
||||
if sesh1 == nil {
|
||||
t.Fatal("get second session: no session returned")
|
||||
}
|
||||
|
||||
if user.NumSession() != 2 {
|
||||
t.Fatal("number of session is not 2")
|
||||
}
|
||||
|
||||
user.CloseSession(0, "")
|
||||
if user.NumSession() != 1 {
|
||||
t.Fatal("number of session is not 1 after deleting one")
|
||||
}
|
||||
if !sesh0.IsClosed() {
|
||||
t.Fatal("session not closed after deletion")
|
||||
}
|
||||
|
||||
user.closeAllSessions("")
|
||||
if !sesh1.IsClosed() {
|
||||
t.Fatal("session not closed after user termination")
|
||||
}
|
||||
|
||||
// get session again after termination
|
||||
seshy, existing, err := user.GetSession(0, getSeshConfig(false))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if existing {
|
||||
t.Fatal("get session again after termination: session returned as existing")
|
||||
}
|
||||
if seshy == nil {
|
||||
t.Fatal("get session again after termination: no session returned")
|
||||
}
|
||||
if seshy == sesh0 || seshy == sesh1 {
|
||||
t.Fatal("get session after termination returned the same instance")
|
||||
}
|
||||
|
||||
user.CloseSession(0, "")
|
||||
if panel.isActive(user.arrUID[:]) {
|
||||
t.Fatal("user still active after last session deleted")
|
||||
}
|
||||
|
||||
err = manager.Close()
|
||||
if err != nil {
|
||||
t.Fatal("failed to close localmanager", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,88 +2,59 @@ package server
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"log"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
type ClientInfo struct {
|
||||
UID []byte
|
||||
SessionId uint32
|
||||
ProxyMethod string
|
||||
EncryptionMethod byte
|
||||
Unordered bool
|
||||
Transport Transport
|
||||
// input ticket, return UID
|
||||
func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, uint32) {
|
||||
ephPub, _ := ecdh.Unmarshal(ticket[0:32])
|
||||
key := ecdh.GenerateSharedSecret(staticPv, ephPub)
|
||||
UIDsID := util.AESDecrypt(ticket[0:16], key, ticket[32:68])
|
||||
sessionID := binary.BigEndian.Uint32(UIDsID[32:36])
|
||||
return UIDsID[0:32], sessionID
|
||||
}
|
||||
|
||||
type authFragments struct {
|
||||
sharedSecret [32]byte
|
||||
randPubKey [32]byte
|
||||
ciphertextWithTag [64]byte
|
||||
func validateRandom(random []byte, UID []byte, time int64) bool {
|
||||
t := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(t, uint64(time/(12*60*60)))
|
||||
rdm := random[0:16]
|
||||
preHash := make([]byte, 56)
|
||||
copy(preHash[0:32], UID)
|
||||
copy(preHash[32:40], t)
|
||||
copy(preHash[40:56], rdm)
|
||||
h := sha256.New()
|
||||
h.Write(preHash)
|
||||
return bytes.Equal(h.Sum(nil)[0:16], random[16:32])
|
||||
}
|
||||
func TouchStone(ch *ClientHello, sta *State) (isSS bool, UID []byte, sessionID uint32) {
|
||||
var random [32]byte
|
||||
copy(random[:], ch.random)
|
||||
|
||||
const (
|
||||
UNORDERED_FLAG = 0x01 // 0000 0001
|
||||
)
|
||||
sta.usedRandomM.Lock()
|
||||
used := sta.usedRandom[random]
|
||||
sta.usedRandom[random] = int(sta.Now().Unix())
|
||||
sta.usedRandomM.Unlock()
|
||||
|
||||
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
|
||||
|
||||
// decryptClientInfo checks if a the authFragments are valid. It doesn't check if the UID is authorised
|
||||
func decryptClientInfo(fragments authFragments, serverTime time.Time) (info ClientInfo, err error) {
|
||||
var plaintext []byte
|
||||
plaintext, err = common.AESGCMDecrypt(fragments.randPubKey[0:12], fragments.sharedSecret[:], fragments.ciphertextWithTag[:])
|
||||
if err != nil {
|
||||
return
|
||||
if used != 0 {
|
||||
log.Println("Replay! Duplicate random")
|
||||
return false, nil, 0
|
||||
}
|
||||
|
||||
info = ClientInfo{
|
||||
UID: plaintext[0:16],
|
||||
SessionId: 0,
|
||||
ProxyMethod: string(bytes.Trim(plaintext[16:28], "\x00")),
|
||||
EncryptionMethod: plaintext[28],
|
||||
Unordered: plaintext[41]&UNORDERED_FLAG != 0,
|
||||
ticket := ch.extensions[[2]byte{0x00, 0x23}]
|
||||
if len(ticket) < 68 {
|
||||
return false, nil, 0
|
||||
}
|
||||
UID, sessionID = decryptSessionTicket(sta.staticPv, ticket)
|
||||
isSS = validateRandom(ch.random, UID, sta.Now().Unix())
|
||||
if !isSS {
|
||||
return false, nil, 0
|
||||
}
|
||||
|
||||
timestamp := int64(binary.BigEndian.Uint64(plaintext[29:37]))
|
||||
clientTime := time.Unix(timestamp, 0)
|
||||
if !(clientTime.After(serverTime.Add(-timestampTolerance)) && clientTime.Before(serverTime.Add(timestampTolerance))) {
|
||||
err = fmt.Errorf("%v: received timestamp %v", ErrTimestampOutOfWindow, timestamp)
|
||||
return
|
||||
}
|
||||
info.SessionId = binary.BigEndian.Uint32(plaintext[37:41])
|
||||
return
|
||||
}
|
||||
|
||||
var ErrReplay = errors.New("duplicate random")
|
||||
var ErrBadProxyMethod = errors.New("invalid proxy method")
|
||||
var ErrBadDecryption = errors.New("decryption/authentication failure")
|
||||
|
||||
// AuthFirstPacket checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client
|
||||
// if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user
|
||||
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
|
||||
// the handshake
|
||||
func AuthFirstPacket(firstPacket []byte, transport Transport, sta *State) (info ClientInfo, finisher Responder, err error) {
|
||||
fragments, finisher, err := transport.processFirstPacket(firstPacket, sta.StaticPv)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if sta.registerRandom(fragments.randPubKey) {
|
||||
err = ErrReplay
|
||||
return
|
||||
}
|
||||
|
||||
info, err = decryptClientInfo(fragments, sta.WorldState.Now().UTC())
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
err = fmt.Errorf("%w: %v", ErrBadDecryption, err)
|
||||
return
|
||||
}
|
||||
info.Transport = transport
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,196 +1,68 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
)
|
||||
|
||||
func TestDecryptClientInfo(t *testing.T) {
|
||||
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
|
||||
p, _ := ecdh.Unmarshal(pvBytes)
|
||||
staticPv := p.(crypto.PrivateKey)
|
||||
func TestDecryptSessionTicket(t *testing.T) {
|
||||
UID, _ := hex.DecodeString("26a8e88bcd7c64a69ca051740851d22a6818de2fddafc00882331f1c5a8b866c")
|
||||
sessionID := uint32(42)
|
||||
pvb, _ := hex.DecodeString("083794692e77b28fa2152dfee53142185fd58ea8172d3545fdeeaea97b3c597c")
|
||||
staticPv, _ := ecdh.Unmarshal(pvb)
|
||||
sessionTicket, _ := hex.DecodeString("f586223b50cada583d61dc9bf3d01cc3a45aab4b062ed6a31ead0badb87f7761aab4f9f737a1d8ff2a2aa4d50ceb808844588ee3c8fdf36c33a35ef5003e287337659c8164a7949e9e63623090763fc24d0386c8904e47bdd740e09dd9b395c72de669629c2a865ed581452d23306adf26de0c8a46ee05e3dac876f2bcd9a2de946d319498f579383d06b3e66b3aca05f533fdc5f017eeba45b42080aabd4f71151fa0dfc1b0e23be4ed3abdb47adc0d5740ca7b7689ad34426309fb6984a086")
|
||||
|
||||
t.Run("correct time", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
ch, _ := parseClientHello(chBytes)
|
||||
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
nineSixSix := time.Unix(1565998966, 0)
|
||||
cinfo, err := decryptClientInfo(ai, nineSixSix)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
if cinfo.SessionId != 3710878841 {
|
||||
t.Errorf("expecting session id 3710878841, got %v", cinfo.SessionId)
|
||||
}
|
||||
})
|
||||
t.Run("roughly correct time", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
ch, _ := parseClientHello(chBytes)
|
||||
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
nineSixSixP50 := time.Unix(1565998966, 0).Add(50)
|
||||
_, err = decryptClientInfo(ai, nineSixSixP50)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
nineSixSixM50 := time.Unix(1565998966, 0).Add(-50)
|
||||
_, err = decryptClientInfo(ai, nineSixSixM50)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
})
|
||||
t.Run("over interval", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
ch, _ := parseClientHello(chBytes)
|
||||
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
nineSixSixOver := time.Unix(1565998966, 0).Add(timestampTolerance + 10)
|
||||
_, err = decryptClientInfo(ai, nineSixSixOver)
|
||||
if err == nil {
|
||||
t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err)
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("under interval", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
ch, _ := parseClientHello(chBytes)
|
||||
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
nineSixSixUnder := time.Unix(1565998966, 0).Add(-(timestampTolerance + 10))
|
||||
_, err = decryptClientInfo(ai, nineSixSixUnder)
|
||||
if err == nil {
|
||||
t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err)
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("not cloak psk", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010246010002420303794ae79c6db7a31e67e2ce91b8afcb82995ae79ad1d0dc885f933e4193bf95cd208abd7a70f3b82cc31c02f1c2b94ba74d5222a66695a5cf92a366421d7f5eb9530022fafa130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001d75a5a00000000001e001c0000196c68332e676f6f676c6575736572636f6e74656e742e636f6d00170000ff01000100000a000a0008baba001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029baba000100001d002074bfe93336c364b43cf0879d997b2e11dc97068b86fc90174e0f2bcea1d4ed1c002d00020101002b000b0ababa0304030303020301001b00030200029a9a0001000029010500e000da00d1f6c0918f865390ae3ca33c77f61a1974cb4533456071b214ec018d17dc22845f2f72cf1dba48f9cdc0758803002dda9b964fad5522e82442af7cbbe242241e39233386f2383bce3ced8e16b1ae3f0ef52a706f58e1e6a1bca0cd3b3a2a4c4cb738770b01b56bf3e73c472bf4fb238cab510aa78f8427a3ca99f741aa433f548be460705f43a3abe878cec6ee3158c129406910b93e798e8a7aaffc2e7ff7b8fd872778d3687a0beaa1452fe7ec418070d537344b64d09f6edd053346ff9c9678eef6b8886882aba81d4be11d9df653de35659f93a22ac39399e3ba400021204e22b73261693967a9216fe4a3b004571c53f316309e76671a18d78931b5b072")
|
||||
ch, _ := parseClientHello(chBytes)
|
||||
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fiveOSix := time.Unix(1565999506, 0)
|
||||
cinfo, err := decryptClientInfo(ai, fiveOSix)
|
||||
if err == nil {
|
||||
t.Errorf("not a cloak, got nil error and cinfo %v", cinfo)
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("not cloak no psk", func(t *testing.T) {
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303eae4c204a867390a758fcff3afa5803cac3e07011cf0c9f3befc1267445aabee20fc398df698113617f8161cbcb89534efa892088a6c5e49246534e05f790ea36f00220a0a130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001910a0a000000000014001200000f63646e2e62697a69626c652e636f6d00170000ff01000100000a000a0008caca001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029caca000100001d00204c8f1563fb70c261bc0c32c1b568b8d02fab25f4094711e7868b1712751dc754002d00020101002b000b0a2a2a0304030303020301001b00030200026a6a000100001500c9000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
ch, _ := parseClientHello(chBytes)
|
||||
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
|
||||
if err != nil {
|
||||
t.Errorf("expecting no error, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sixOneFive := time.Unix(1565999615, 0)
|
||||
cinfo, err := decryptClientInfo(ai, sixOneFive)
|
||||
if err == nil {
|
||||
t.Errorf("not a cloak, got nil error and cinfo %v", cinfo)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestAuthFirstPacket(t *testing.T) {
|
||||
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
|
||||
p, _ := ecdh.Unmarshal(pvBytes)
|
||||
|
||||
getNewState := func() *State {
|
||||
sta, _ := InitState(RawConfig{}, common.WorldOfTime(time.Unix(1565998966, 0)))
|
||||
sta.StaticPv = p.(crypto.PrivateKey)
|
||||
sta.ProxyBook["shadowsocks"] = nil
|
||||
return sta
|
||||
decryUID, decrySessionID := decryptSessionTicket(staticPv, sessionTicket)
|
||||
if !bytes.Equal(decryUID, UID) {
|
||||
t.Error(
|
||||
"For", "UID",
|
||||
"expecting", fmt.Sprintf("%x", UID),
|
||||
"got", fmt.Sprintf("%x", decryUID),
|
||||
)
|
||||
}
|
||||
if decrySessionID != sessionID {
|
||||
t.Error(
|
||||
"For", "sessionID",
|
||||
"expecting", fmt.Sprintf("%x", sessionID),
|
||||
"got", fmt.Sprintf("%x", decrySessionID),
|
||||
)
|
||||
}
|
||||
|
||||
t.Run("TLS correct", func(t *testing.T) {
|
||||
sta := getNewState()
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
info, _, err := AuthFirstPacket(chBytes, TLS{}, sta)
|
||||
if err != nil {
|
||||
t.Errorf("failed to get client info: %v", err)
|
||||
return
|
||||
}
|
||||
if info.SessionId != 3710878841 {
|
||||
t.Error("failed to get correct session id")
|
||||
return
|
||||
}
|
||||
if info.Transport.(fmt.Stringer).String() != "TLS" {
|
||||
t.Errorf("wrong transport: %v", info.Transport)
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("TLS correct but replay", func(t *testing.T) {
|
||||
sta := getNewState()
|
||||
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
_, _, err := AuthFirstPacket(chBytes, TLS{}, sta)
|
||||
if err != nil {
|
||||
t.Error("failed to prepare for the first time")
|
||||
return
|
||||
}
|
||||
_, _, err = AuthFirstPacket(chBytes, TLS{}, sta)
|
||||
if err != ErrReplay {
|
||||
t.Errorf("failed to return ErrReplay, got %v instead", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
t.Run("Websocket correct", func(t *testing.T) {
|
||||
sta, _ := InitState(RawConfig{}, common.WorldOfTime(time.Unix(1584358419, 0)))
|
||||
sta.StaticPv = p.(crypto.PrivateKey)
|
||||
sta.ProxyBook["shadowsocks"] = nil
|
||||
}
|
||||
|
||||
req := `GET / HTTP/1.1
|
||||
Host: d2jkinvisak5y9.cloudfront.net:443
|
||||
User-Agent: Go-http-client/1.1
|
||||
Connection: Upgrade
|
||||
Hidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ
|
||||
Sec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==
|
||||
Sec-WebSocket-Version: 13
|
||||
Upgrade: websocket
|
||||
func TestValidateRandom(t *testing.T) {
|
||||
UID, _ := hex.DecodeString("26a8e88bcd7c64a69ca051740851d22a6818de2fddafc00882331f1c5a8b866c")
|
||||
random, _ := hex.DecodeString("6274de9992a6f96a86fc35cf6644a5e7844951889a802e9531add440eabb939b")
|
||||
right := validateRandom(random, UID, 1547912444)
|
||||
if !right {
|
||||
t.Error(
|
||||
"For", fmt.Sprintf("good random: %x at time %v", random, 1547912444),
|
||||
"expecting", true,
|
||||
"got", false,
|
||||
)
|
||||
}
|
||||
|
||||
`
|
||||
info, _, err := AuthFirstPacket([]byte(req), WebSocket{}, sta)
|
||||
if err != nil {
|
||||
t.Errorf("failed to get client info: %v", err)
|
||||
return
|
||||
}
|
||||
if info.Transport.(fmt.Stringer).String() != "WebSocket" {
|
||||
t.Errorf("wrong transport: %v", info.Transport)
|
||||
return
|
||||
}
|
||||
})
|
||||
replay := validateRandom(random, UID, 1547955645)
|
||||
if replay {
|
||||
t.Error(
|
||||
"For", fmt.Sprintf("expired random: %x at time %v", random, 1547955645),
|
||||
"expecting", false,
|
||||
"got", true,
|
||||
)
|
||||
}
|
||||
|
||||
random[13] = 0x42
|
||||
bogus := validateRandom(random, UID, 1547912444)
|
||||
if bogus {
|
||||
t.Error(
|
||||
"For", fmt.Sprintf("bogus random: %x at time %v", random, 1547912444),
|
||||
"expecting", false,
|
||||
"got", true,
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,311 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var b64 = base64.StdEncoding.EncodeToString
|
||||
|
||||
const firstPacketSize = 3000
|
||||
|
||||
func Serve(l net.Listener, sta *State) {
|
||||
waitDur := [10]time.Duration{
|
||||
50 * time.Millisecond, 100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second,
|
||||
3 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second, 30 * time.Second}
|
||||
|
||||
fails := 0
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Errorf("%v, retrying", err)
|
||||
time.Sleep(waitDur[fails])
|
||||
if fails < 9 {
|
||||
fails++
|
||||
}
|
||||
continue
|
||||
}
|
||||
fails = 0
|
||||
go dispatchConnection(conn, sta)
|
||||
}
|
||||
}
|
||||
|
||||
func connReadLine(conn net.Conn, buf []byte) (int, error) {
|
||||
i := 0
|
||||
for ; i < len(buf); i++ {
|
||||
_, err := io.ReadFull(conn, buf[i:i+1])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
if buf[i] == '\n' {
|
||||
return i + 1, nil
|
||||
}
|
||||
}
|
||||
return i, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
var ErrUnrecognisedProtocol = errors.New("unrecognised protocol")
|
||||
|
||||
func readFirstPacket(conn net.Conn, buf []byte, timeout time.Duration) (int, Transport, bool, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
|
||||
_, err := io.ReadFull(conn, buf[:1])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read error after connection is established: %v", err)
|
||||
conn.Close()
|
||||
return 0, nil, false, err
|
||||
}
|
||||
|
||||
// TODO: give the option to match the protocol with port
|
||||
bufOffset := 1
|
||||
var transport Transport
|
||||
switch buf[0] {
|
||||
case 0x16:
|
||||
transport = TLS{}
|
||||
recordLayerLength := 5
|
||||
|
||||
i, err := io.ReadFull(conn, buf[bufOffset:recordLayerLength])
|
||||
bufOffset += i
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read error after connection is established: %v", err)
|
||||
conn.Close()
|
||||
return bufOffset, transport, false, err
|
||||
}
|
||||
dataLength := int(binary.BigEndian.Uint16(buf[3:5]))
|
||||
if dataLength+recordLayerLength > len(buf) {
|
||||
return bufOffset, transport, true, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
i, err = io.ReadFull(conn, buf[recordLayerLength:dataLength+recordLayerLength])
|
||||
bufOffset += i
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read error after connection is established: %v", err)
|
||||
conn.Close()
|
||||
return bufOffset, transport, false, err
|
||||
}
|
||||
case 0x47:
|
||||
transport = WebSocket{}
|
||||
|
||||
for {
|
||||
i, err := connReadLine(conn, buf[bufOffset:])
|
||||
line := buf[bufOffset : bufOffset+i]
|
||||
bufOffset += i
|
||||
if err != nil {
|
||||
if err == io.ErrShortBuffer {
|
||||
return bufOffset, transport, true, err
|
||||
} else {
|
||||
err = fmt.Errorf("error reading first packet: %v", err)
|
||||
conn.Close()
|
||||
return bufOffset, transport, false, err
|
||||
}
|
||||
}
|
||||
|
||||
if bytes.Equal(line, []byte("\r\n")) {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
return bufOffset, transport, true, ErrUnrecognisedProtocol
|
||||
}
|
||||
return bufOffset, transport, true, nil
|
||||
}
|
||||
|
||||
func dispatchConnection(conn net.Conn, sta *State) {
|
||||
var err error
|
||||
buf := make([]byte, firstPacketSize)
|
||||
|
||||
i, transport, redirOnErr, err := readFirstPacket(conn, buf, 15*time.Second)
|
||||
data := buf[:i]
|
||||
|
||||
goWeb := func() {
|
||||
redirPort := sta.RedirPort
|
||||
if redirPort == "" {
|
||||
_, redirPort, _ = net.SplitHostPort(conn.LocalAddr().String())
|
||||
}
|
||||
webConn, err := sta.RedirDialer.Dial("tcp", net.JoinHostPort(sta.RedirHost.String(), redirPort))
|
||||
if err != nil {
|
||||
log.Errorf("Making connection to redirection server: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = webConn.Write(data)
|
||||
if err != nil {
|
||||
log.Error("Failed to send first packet to redirection server", err)
|
||||
return
|
||||
}
|
||||
go common.Copy(webConn, conn)
|
||||
go common.Copy(conn, webConn)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.WithField("remoteAddr", conn.RemoteAddr()).
|
||||
Warnf("error reading first packet: %v", err)
|
||||
if redirOnErr {
|
||||
goWeb()
|
||||
} else {
|
||||
conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ci, finishHandshake, err := AuthFirstPacket(data, transport, sta)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"remoteAddr": conn.RemoteAddr(),
|
||||
"UID": b64(ci.UID),
|
||||
"sessionId": ci.SessionId,
|
||||
"proxyMethod": ci.ProxyMethod,
|
||||
"encryptionMethod": ci.EncryptionMethod,
|
||||
}).Warn(err)
|
||||
goWeb()
|
||||
return
|
||||
}
|
||||
|
||||
var sessionKey [32]byte
|
||||
common.RandRead(sta.WorldState.Rand, sessionKey[:])
|
||||
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"remoteAddr": conn.RemoteAddr(),
|
||||
"UID": b64(ci.UID),
|
||||
"sessionId": ci.SessionId,
|
||||
"proxyMethod": ci.ProxyMethod,
|
||||
"encryptionMethod": ci.EncryptionMethod,
|
||||
}).Error(err)
|
||||
goWeb()
|
||||
return
|
||||
}
|
||||
|
||||
seshConfig := mux.SessionConfig{
|
||||
Obfuscator: obfuscator,
|
||||
Valve: nil,
|
||||
Unordered: ci.Unordered,
|
||||
MsgOnWireSizeLimit: appDataMaxLength,
|
||||
}
|
||||
|
||||
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
|
||||
// added to the userinfo database. The distinction between going into the admin mode
|
||||
// and normal proxy mode is that sessionID needs == 0 for admin mode
|
||||
if len(sta.AdminUID) != 0 && bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 {
|
||||
sesh := mux.MakeSession(0, seshConfig)
|
||||
preparedConn, err := finishHandshake(conn, sessionKey, sta.WorldState.Rand)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
log.Trace("finished handshake")
|
||||
sesh.AddConnection(preparedConn)
|
||||
//TODO: Router could be nil in cnc mode
|
||||
log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session")
|
||||
err = http.Serve(sesh, usermanager.APIRouterOf(sta.Panel.Manager))
|
||||
// http.Serve never returns with non-nil error
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := sta.ProxyBook[ci.ProxyMethod]; !ok {
|
||||
log.WithFields(log.Fields{
|
||||
"remoteAddr": conn.RemoteAddr(),
|
||||
"UID": b64(ci.UID),
|
||||
"sessionId": ci.SessionId,
|
||||
"proxyMethod": ci.ProxyMethod,
|
||||
"encryptionMethod": ci.EncryptionMethod,
|
||||
}).Error(ErrBadProxyMethod)
|
||||
goWeb()
|
||||
return
|
||||
}
|
||||
|
||||
var user *ActiveUser
|
||||
if sta.IsBypass(ci.UID) {
|
||||
user, err = sta.Panel.GetBypassUser(ci.UID)
|
||||
} else {
|
||||
user, err = sta.Panel.GetUser(ci.UID)
|
||||
}
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"UID": b64(ci.UID),
|
||||
"remoteAddr": conn.RemoteAddr(),
|
||||
"error": err,
|
||||
}).Warn("+1 unauthorised UID")
|
||||
goWeb()
|
||||
return
|
||||
}
|
||||
|
||||
sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
|
||||
if err != nil {
|
||||
user.CloseSession(ci.SessionId, "")
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
preparedConn, err := finishHandshake(conn, sesh.GetSessionKey(), sta.WorldState.Rand)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
log.Trace("finished handshake")
|
||||
sesh.AddConnection(preparedConn)
|
||||
|
||||
if !existing {
|
||||
// if the session was newly made, we serve connections from the session streams to the proxy server
|
||||
log.WithFields(log.Fields{
|
||||
"UID": b64(ci.UID),
|
||||
"sessionID": ci.SessionId,
|
||||
}).Info("New session")
|
||||
|
||||
serveSession(sesh, ci, user, sta)
|
||||
}
|
||||
}
|
||||
|
||||
func serveSession(sesh *mux.Session, ci ClientInfo, user *ActiveUser, sta *State) error {
|
||||
for {
|
||||
newStream, err := sesh.Accept()
|
||||
if err != nil {
|
||||
if err == mux.ErrBrokenSession {
|
||||
log.WithFields(log.Fields{
|
||||
"UID": b64(ci.UID),
|
||||
"sessionID": ci.SessionId,
|
||||
"reason": sesh.TerminalMsg(),
|
||||
}).Info("Session closed")
|
||||
user.CloseSession(ci.SessionId, "")
|
||||
return nil
|
||||
} else {
|
||||
log.Errorf("unhandled error on session.Accept(): %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
proxyAddr := sta.ProxyBook[ci.ProxyMethod]
|
||||
localConn, err := sta.ProxyDialer.Dial(proxyAddr.Network(), proxyAddr.String())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err)
|
||||
user.CloseSession(ci.SessionId, "Failed to connect to proxy server")
|
||||
return err
|
||||
}
|
||||
log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
|
||||
|
||||
go func() {
|
||||
if _, err := common.Copy(localConn, newStream); err != nil {
|
||||
log.Tracef("copying stream to proxy server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := common.Copy(newStream, localConn); err != nil {
|
||||
log.Tracef("copying proxy server to stream: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,211 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/connutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type rfpReturnValue struct {
|
||||
n int
|
||||
transport Transport
|
||||
redirOnErr bool
|
||||
err error
|
||||
}
|
||||
|
||||
const timeout = 500 * time.Millisecond
|
||||
|
||||
func TestReadFirstPacket(t *testing.T) {
|
||||
rfp := func(conn net.Conn, buf []byte, retChan chan<- rfpReturnValue) {
|
||||
ret := rfpReturnValue{}
|
||||
ret.n, ret.transport, ret.redirOnErr, ret.err = readFirstPacket(conn, buf, timeout)
|
||||
retChan <- ret
|
||||
}
|
||||
|
||||
t.Run("Good TLS", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
local.Write(first)
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, len(first), ret.n)
|
||||
assert.Equal(t, first, buf[:ret.n])
|
||||
assert.IsType(t, TLS{}, ret.transport)
|
||||
assert.NoError(t, ret.err)
|
||||
})
|
||||
|
||||
t.Run("Good TLS but buf too small", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 10)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
local.Write(first)
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, io.ErrShortBuffer, ret.err)
|
||||
assert.True(t, ret.redirOnErr)
|
||||
assert.Equal(t, first[:ret.n], buf[:ret.n])
|
||||
|
||||
})
|
||||
|
||||
t.Run("Incomplete timeout", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
first, _ := hex.DecodeString("160301")
|
||||
local.Write(first)
|
||||
select {
|
||||
case ret := <-retChan:
|
||||
assert.Equal(t, len(first), ret.n)
|
||||
assert.False(t, ret.redirOnErr)
|
||||
assert.Error(t, ret.err)
|
||||
case <-time.After(2 * timeout):
|
||||
assert.Fail(t, "readFirstPacket should have timed out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Incomplete payload timeout", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
first, _ := hex.DecodeString("16030101010000")
|
||||
local.Write(first)
|
||||
select {
|
||||
case ret := <-retChan:
|
||||
assert.Equal(t, len(first), ret.n)
|
||||
assert.False(t, ret.redirOnErr)
|
||||
assert.Error(t, ret.err)
|
||||
case <-time.After(2 * timeout):
|
||||
assert.Fail(t, "readFirstPacket should have timed out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Good TLS staggered", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
local.Write(first[:100])
|
||||
time.Sleep(timeout / 2)
|
||||
local.Write(first[100:])
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, len(first), ret.n)
|
||||
assert.Equal(t, first, buf[:ret.n])
|
||||
assert.IsType(t, TLS{}, ret.transport)
|
||||
assert.NoError(t, ret.err)
|
||||
})
|
||||
|
||||
t.Run("TLS bad recordlayer length", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
first, _ := hex.DecodeString("160301ffff")
|
||||
local.Write(first)
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, len(first), ret.n)
|
||||
assert.Equal(t, first, buf[:ret.n])
|
||||
assert.IsType(t, TLS{}, ret.transport)
|
||||
assert.Equal(t, io.ErrShortBuffer, ret.err)
|
||||
assert.True(t, ret.redirOnErr)
|
||||
})
|
||||
|
||||
t.Run("Good WebSocket", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
|
||||
req := []byte(reqStr)
|
||||
local.Write(req)
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, len(req), ret.n)
|
||||
assert.Equal(t, req, buf[:ret.n])
|
||||
assert.IsType(t, WebSocket{}, ret.transport)
|
||||
assert.NoError(t, ret.err)
|
||||
})
|
||||
|
||||
t.Run("Good WebSocket but buf too small", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 10)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
|
||||
req := []byte(reqStr)
|
||||
local.Write(req)
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, io.ErrShortBuffer, ret.err)
|
||||
assert.True(t, ret.redirOnErr)
|
||||
assert.Equal(t, req[:ret.n], buf[:ret.n])
|
||||
})
|
||||
|
||||
t.Run("Incomplete WebSocket timeout", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
reqStr := "GET /"
|
||||
req := []byte(reqStr)
|
||||
local.Write(req)
|
||||
|
||||
select {
|
||||
case ret := <-retChan:
|
||||
assert.Equal(t, len(req), ret.n)
|
||||
assert.False(t, ret.redirOnErr)
|
||||
assert.Error(t, ret.err)
|
||||
case <-time.After(2 * timeout):
|
||||
assert.Fail(t, "readFirstPacket should have timed out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Staggered WebSocket", func(t *testing.T) {
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
|
||||
req := []byte(reqStr)
|
||||
local.Write(req[:100])
|
||||
time.Sleep(timeout / 2)
|
||||
local.Write(req[100:])
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
assert.Equal(t, len(req), ret.n)
|
||||
assert.Equal(t, req, buf[:ret.n])
|
||||
assert.IsType(t, WebSocket{}, ret.transport)
|
||||
assert.NoError(t, ret.err)
|
||||
})
|
||||
}
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
//go:build gofuzz
|
||||
// +build gofuzz
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/connutil"
|
||||
)
|
||||
|
||||
type rfpReturnValue_fuzz struct {
|
||||
n int
|
||||
transport Transport
|
||||
redirOnErr bool
|
||||
err error
|
||||
}
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
var bypassUID [16]byte
|
||||
|
||||
var pv [32]byte
|
||||
|
||||
sta := &State{
|
||||
BypassUID: map[[16]byte]struct{}{
|
||||
bypassUID: {},
|
||||
},
|
||||
ProxyBook: map[string]net.Addr{
|
||||
"shadowsocks": nil,
|
||||
},
|
||||
UsedRandom: map[[32]byte]int64{},
|
||||
StaticPv: &pv,
|
||||
WorldState: common.RealWorldState,
|
||||
}
|
||||
|
||||
rfp := func(conn net.Conn, buf []byte, retChan chan<- rfpReturnValue_fuzz) {
|
||||
ret := rfpReturnValue_fuzz{}
|
||||
ret.n, ret.transport, ret.redirOnErr, ret.err = readFirstPacket(conn, buf, 500*time.Millisecond)
|
||||
retChan <- ret
|
||||
}
|
||||
|
||||
local, remote := connutil.AsyncPipe()
|
||||
buf := make([]byte, 1500)
|
||||
retChan := make(chan rfpReturnValue_fuzz)
|
||||
go rfp(remote, buf, retChan)
|
||||
|
||||
local.Write(data)
|
||||
|
||||
ret := <-retChan
|
||||
|
||||
if ret.err != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
_, _, err := AuthFirstPacket(buf[:ret.n], ret.transport, sta)
|
||||
|
||||
if !errors.Is(err, ErrReplay) && !errors.Is(err, ErrBadDecryption) {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
@ -2,232 +2,132 @@ package server
|
|||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
)
|
||||
|
||||
type RawConfig struct {
|
||||
ProxyBook map[string][]string
|
||||
BindAddr []string
|
||||
BypassUID [][]byte
|
||||
RedirAddr string
|
||||
PrivateKey []byte
|
||||
AdminUID []byte
|
||||
DatabasePath string
|
||||
KeepAlive int
|
||||
CncMode bool
|
||||
type rawConfig struct {
|
||||
WebServerAddr string
|
||||
PrivateKey string
|
||||
AdminUID string
|
||||
DatabasePath string
|
||||
BackupDirPath string
|
||||
}
|
||||
|
||||
// State type stores the global state of the program
|
||||
type State struct {
|
||||
ProxyBook map[string]net.Addr
|
||||
ProxyDialer common.Dialer
|
||||
|
||||
WorldState common.WorldState
|
||||
AdminUID []byte
|
||||
|
||||
BypassUID map[[16]byte]struct{}
|
||||
StaticPv crypto.PrivateKey
|
||||
|
||||
// TODO: this doesn't have to be a net.Addr; resolution is done in Dial automatically
|
||||
RedirHost net.Addr
|
||||
RedirPort string
|
||||
RedirDialer common.Dialer
|
||||
SS_LOCAL_HOST string
|
||||
SS_LOCAL_PORT string
|
||||
SS_REMOTE_HOST string
|
||||
SS_REMOTE_PORT string
|
||||
|
||||
Now func() time.Time
|
||||
AdminUID []byte
|
||||
staticPv crypto.PrivateKey
|
||||
Userpanel *usermanager.Userpanel
|
||||
usedRandomM sync.RWMutex
|
||||
UsedRandom map[[32]byte]int64
|
||||
usedRandom map[[32]byte]int
|
||||
|
||||
Panel *userPanel
|
||||
WebServerAddr string
|
||||
}
|
||||
|
||||
func parseRedirAddr(redirAddr string) (net.Addr, string, error) {
|
||||
var host string
|
||||
var port string
|
||||
colonSep := strings.Split(redirAddr, ":")
|
||||
if len(colonSep) > 1 {
|
||||
if len(colonSep) == 2 {
|
||||
// domain or ipv4 with port
|
||||
host = colonSep[0]
|
||||
port = colonSep[1]
|
||||
} else {
|
||||
if strings.Contains(redirAddr, "[") {
|
||||
// ipv6 with port
|
||||
port = colonSep[len(colonSep)-1]
|
||||
host = strings.TrimSuffix(redirAddr, "]:"+port)
|
||||
host = strings.TrimPrefix(host, "[")
|
||||
} else {
|
||||
// ipv6 without port
|
||||
host = redirAddr
|
||||
}
|
||||
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) (*State, error) {
|
||||
ret := &State{
|
||||
SS_LOCAL_HOST: localHost,
|
||||
SS_LOCAL_PORT: localPort,
|
||||
SS_REMOTE_HOST: remoteHost,
|
||||
SS_REMOTE_PORT: remotePort,
|
||||
Now: nowFunc,
|
||||
}
|
||||
ret.usedRandom = make(map[[32]byte]int)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// semi-colon separated value.
|
||||
func ssvToJson(ssv string) (ret []byte) {
|
||||
unescape := func(s string) string {
|
||||
r := strings.Replace(s, `\\`, `\`, -1)
|
||||
r = strings.Replace(r, `\=`, `=`, -1)
|
||||
r = strings.Replace(r, `\;`, `;`, -1)
|
||||
return r
|
||||
}
|
||||
lines := strings.Split(unescape(ssv), ";")
|
||||
ret = []byte("{")
|
||||
for _, ln := range lines {
|
||||
if ln == "" {
|
||||
break
|
||||
}
|
||||
sp := strings.SplitN(ln, "=", 2)
|
||||
key := sp[0]
|
||||
value := sp[1]
|
||||
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)
|
||||
|
||||
}
|
||||
ret = ret[:len(ret)-1] // remove the last comma
|
||||
ret = append(ret, '}')
|
||||
return ret
|
||||
}
|
||||
|
||||
// ParseConfig parses the config (either a path to json or in-line ssv config) into a State variable
|
||||
func (sta *State) ParseConfig(conf string) (err error) {
|
||||
var content []byte
|
||||
if strings.Contains(conf, ";") && strings.Contains(conf, "=") {
|
||||
content = ssvToJson(conf)
|
||||
} else {
|
||||
// domain or ipv4 without port
|
||||
host = redirAddr
|
||||
content, err = ioutil.ReadFile(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
redirHost, err := net.ResolveIPAddr("ip", host)
|
||||
var preParse rawConfig
|
||||
err = json.Unmarshal(content, &preParse)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unable to resolve RedirAddr: %v. ", err)
|
||||
}
|
||||
return redirHost, port, nil
|
||||
}
|
||||
|
||||
func parseProxyBook(bookEntries map[string][]string) (map[string]net.Addr, error) {
|
||||
proxyBook := map[string]net.Addr{}
|
||||
for name, pair := range bookEntries {
|
||||
name = strings.ToLower(name)
|
||||
if len(pair) != 2 {
|
||||
return nil, fmt.Errorf("invalid proxy endpoint and address pair for %v: %v", name, pair)
|
||||
}
|
||||
network := strings.ToLower(pair[0])
|
||||
switch network {
|
||||
case "tcp":
|
||||
addr, err := net.ResolveTCPAddr("tcp", pair[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxyBook[name] = addr
|
||||
continue
|
||||
case "udp":
|
||||
addr, err := net.ResolveUDPAddr("udp", pair[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxyBook[name] = addr
|
||||
continue
|
||||
}
|
||||
}
|
||||
return proxyBook, nil
|
||||
}
|
||||
|
||||
// ParseConfig reads the config file or semicolon-separated options and parse them into a RawConfig
|
||||
func ParseConfig(conf string) (raw RawConfig, err error) {
|
||||
content, errPath := ioutil.ReadFile(conf)
|
||||
if errPath != nil {
|
||||
errJson := json.Unmarshal(content, &raw)
|
||||
if errJson != nil {
|
||||
err = fmt.Errorf("failed to read/unmarshal configuration, path is invalid or %v", errJson)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
errJson := json.Unmarshal(content, &raw)
|
||||
if errJson != nil {
|
||||
err = fmt.Errorf("failed to read configuration file: %v", errJson)
|
||||
return
|
||||
}
|
||||
}
|
||||
if raw.ProxyBook == nil {
|
||||
raw.ProxyBook = make(map[string][]string)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// InitState process the RawConfig and initialises a server State accordingly
|
||||
func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, err error) {
|
||||
sta = &State{
|
||||
BypassUID: make(map[[16]byte]struct{}),
|
||||
ProxyBook: map[string]net.Addr{},
|
||||
UsedRandom: map[[32]byte]int64{},
|
||||
RedirDialer: &net.Dialer{},
|
||||
WorldState: worldState,
|
||||
}
|
||||
if preParse.CncMode {
|
||||
err = errors.New("command & control mode not implemented")
|
||||
return
|
||||
} else {
|
||||
var manager usermanager.UserManager
|
||||
if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" {
|
||||
manager = &usermanager.Voidmanager{}
|
||||
} else {
|
||||
manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState)
|
||||
if err != nil {
|
||||
return sta, err
|
||||
}
|
||||
}
|
||||
sta.Panel = MakeUserPanel(manager)
|
||||
return errors.New("Failed to unmarshal: " + err.Error())
|
||||
}
|
||||
|
||||
if preParse.KeepAlive <= 0 {
|
||||
sta.ProxyDialer = &net.Dialer{KeepAlive: -1}
|
||||
} else {
|
||||
sta.ProxyDialer = &net.Dialer{KeepAlive: time.Duration(preParse.KeepAlive) * time.Second}
|
||||
}
|
||||
|
||||
sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr)
|
||||
up, err := usermanager.MakeUserpanel(preParse.DatabasePath, preParse.BackupDirPath)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to parse RedirAddr: %v", err)
|
||||
return
|
||||
return errors.New("Attempting to open database: " + err.Error())
|
||||
}
|
||||
sta.Userpanel = up
|
||||
|
||||
sta.ProxyBook, err = parseProxyBook(preParse.ProxyBook)
|
||||
sta.WebServerAddr = preParse.WebServerAddr
|
||||
|
||||
pvBytes, err := base64.StdEncoding.DecodeString(preParse.PrivateKey)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to parse ProxyBook: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(preParse.PrivateKey) == 0 {
|
||||
err = fmt.Errorf("must have a valid private key. Run `ck-server -key` to generate one")
|
||||
return
|
||||
return errors.New("Failed to decode private key: " + err.Error())
|
||||
}
|
||||
var pv [32]byte
|
||||
copy(pv[:], preParse.PrivateKey)
|
||||
sta.StaticPv = &pv
|
||||
copy(pv[:], pvBytes)
|
||||
sta.staticPv = &pv
|
||||
|
||||
sta.AdminUID = preParse.AdminUID
|
||||
|
||||
var arrUID [16]byte
|
||||
for _, UID := range preParse.BypassUID {
|
||||
copy(arrUID[:], UID)
|
||||
sta.BypassUID[arrUID] = struct{}{}
|
||||
adminUID, err := base64.StdEncoding.DecodeString(preParse.AdminUID)
|
||||
if err != nil {
|
||||
return errors.New("Failed to decode AdminUID: " + err.Error())
|
||||
}
|
||||
if len(sta.AdminUID) != 0 {
|
||||
copy(arrUID[:], sta.AdminUID)
|
||||
sta.BypassUID[arrUID] = struct{}{}
|
||||
}
|
||||
|
||||
go sta.UsedRandomCleaner()
|
||||
return sta, nil
|
||||
sta.AdminUID = adminUID
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsBypass checks if a UID is a bypass user
|
||||
func (sta *State) IsBypass(UID []byte) bool {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
_, exist := sta.BypassUID[arrUID]
|
||||
return exist
|
||||
}
|
||||
|
||||
const timestampTolerance = 180 * time.Second
|
||||
|
||||
const replayCacheAgeLimit = 12 * time.Hour
|
||||
|
||||
// UsedRandomCleaner clears the cache of used random fields every replayCacheAgeLimit
|
||||
// UsedRandomCleaner clears the cache of used random fields every 12 hours
|
||||
func (sta *State) UsedRandomCleaner() {
|
||||
for {
|
||||
time.Sleep(replayCacheAgeLimit)
|
||||
time.Sleep(12 * time.Hour)
|
||||
now := int(sta.Now().Unix())
|
||||
sta.usedRandomM.Lock()
|
||||
for key, t := range sta.UsedRandom {
|
||||
if time.Unix(t, 0).Before(sta.WorldState.Now().Add(timestampTolerance)) {
|
||||
delete(sta.UsedRandom, key)
|
||||
for key, t := range sta.usedRandom {
|
||||
if now-t > 12*3600 {
|
||||
delete(sta.usedRandom, key)
|
||||
}
|
||||
}
|
||||
sta.usedRandomM.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (sta *State) registerRandom(r [32]byte) bool {
|
||||
sta.usedRandomM.Lock()
|
||||
_, used := sta.UsedRandom[r]
|
||||
sta.UsedRandom[r] = sta.WorldState.Now().Unix()
|
||||
sta.usedRandomM.Unlock()
|
||||
return used
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,126 +1,20 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseRedirAddr(t *testing.T) {
|
||||
t.Run("ipv4 without port", func(t *testing.T) {
|
||||
ipv4noPort := "1.2.3.4"
|
||||
host, port, err := parseRedirAddr(ipv4noPort)
|
||||
if err != nil {
|
||||
t.Errorf("parsing %v error: %v", ipv4noPort, err)
|
||||
return
|
||||
}
|
||||
if host.String() != "1.2.3.4" {
|
||||
t.Errorf("expected %v got %v", "1.2.3.4", host.String())
|
||||
}
|
||||
if port != "" {
|
||||
t.Errorf("port not empty when there is no port")
|
||||
}
|
||||
})
|
||||
func TestSSVtoJson(t *testing.T) {
|
||||
ssv := "WebServerAddr=204.79.197.200:443;PrivateKey=EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=;AdminUID=ugDmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ=;DatabasePath=userinfo.db;BackupDirPath=;"
|
||||
json := ssvToJson(ssv)
|
||||
expected := []byte(`{"WebServerAddr":"204.79.197.200:443","PrivateKey":"EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=","AdminUID":"ugDmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ=","DatabasePath":"userinfo.db","BackupDirPath":""}`)
|
||||
if !bytes.Equal(expected, json) {
|
||||
t.Error(
|
||||
"For", "ssvToJson",
|
||||
"expecting", string(expected),
|
||||
"got", string(json),
|
||||
)
|
||||
}
|
||||
|
||||
t.Run("ipv4 with port", func(t *testing.T) {
|
||||
ipv4wPort := "1.2.3.4:1234"
|
||||
host, port, err := parseRedirAddr(ipv4wPort)
|
||||
if err != nil {
|
||||
t.Errorf("parsing %v error: %v", ipv4wPort, err)
|
||||
return
|
||||
}
|
||||
if host.String() != "1.2.3.4" {
|
||||
t.Errorf("expected %v got %v", "1.2.3.4", host.String())
|
||||
}
|
||||
if port != "1234" {
|
||||
t.Errorf("wrong port: expected %v, got %v", "1234", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("domain without port", func(t *testing.T) {
|
||||
domainNoPort := "example.com"
|
||||
host, port, err := parseRedirAddr(domainNoPort)
|
||||
if err != nil {
|
||||
t.Errorf("parsing %v error: %v", domainNoPort, err)
|
||||
return
|
||||
}
|
||||
|
||||
expIPs, err := net.LookupIP("example.com")
|
||||
if err != nil {
|
||||
t.Errorf("tester error: cannot resolve example.com: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
contain := false
|
||||
for _, expIP := range expIPs {
|
||||
if expIP.String() == host.String() {
|
||||
contain = true
|
||||
}
|
||||
}
|
||||
|
||||
if !contain {
|
||||
t.Errorf("expected one of %v got %v", expIPs, host.String())
|
||||
}
|
||||
if port != "" {
|
||||
t.Errorf("port not empty when there is no port")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("domain with port", func(t *testing.T) {
|
||||
domainWPort := "example.com:80"
|
||||
host, port, err := parseRedirAddr(domainWPort)
|
||||
if err != nil {
|
||||
t.Errorf("parsing %v error: %v", domainWPort, err)
|
||||
return
|
||||
}
|
||||
|
||||
expIPs, err := net.LookupIP("example.com")
|
||||
if err != nil {
|
||||
t.Errorf("tester error: cannot resolve example.com: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
contain := false
|
||||
for _, expIP := range expIPs {
|
||||
if expIP.String() == host.String() {
|
||||
contain = true
|
||||
}
|
||||
}
|
||||
|
||||
if !contain {
|
||||
t.Errorf("expected one of %v got %v", expIPs, host.String())
|
||||
}
|
||||
if port != "80" {
|
||||
t.Errorf("wrong port: expected %v, got %v", "80", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ipv6 without port", func(t *testing.T) {
|
||||
ipv6noPort := "a:b:c:d::"
|
||||
host, port, err := parseRedirAddr(ipv6noPort)
|
||||
if err != nil {
|
||||
t.Errorf("parsing %v error: %v", ipv6noPort, err)
|
||||
return
|
||||
}
|
||||
if host.String() != "a:b:c:d::" {
|
||||
t.Errorf("expected %v got %v", "a:b:c:d::", host.String())
|
||||
}
|
||||
if port != "" {
|
||||
t.Errorf("port not empty when there is no port")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ipv6 with port", func(t *testing.T) {
|
||||
ipv6wPort := "[a:b:c:d::]:80"
|
||||
host, port, err := parseRedirAddr(ipv6wPort)
|
||||
if err != nil {
|
||||
t.Errorf("parsing %v error: %v", ipv6wPort, err)
|
||||
return
|
||||
}
|
||||
if host.String() != "a:b:c:d::" {
|
||||
t.Errorf("expected %v got %v", "a:b:c:d::", host.String())
|
||||
}
|
||||
if port != "80" {
|
||||
t.Errorf("wrong port: expected %v, got %v", "80", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Responder = func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error)
|
||||
type Transport interface {
|
||||
processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (authFragments, Responder, error)
|
||||
}
|
||||
|
||||
var ErrInvalidPubKey = errors.New("public key has invalid format")
|
||||
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
swagger: '2.0'
|
||||
info:
|
||||
description: |
|
||||
This is the API of Cloak server
|
||||
version: 0.0.2
|
||||
title: Cloak Server
|
||||
contact:
|
||||
email: cbeuw.andy@gmail.com
|
||||
license:
|
||||
name: GPLv3
|
||||
url: https://www.gnu.org/licenses/gpl-3.0.en.html
|
||||
# host: petstore.swagger.io
|
||||
# basePath: /v2
|
||||
tags:
|
||||
- name: users
|
||||
description: Operations related to user controls by admin
|
||||
# schemes:
|
||||
# - http
|
||||
paths:
|
||||
/admin/users:
|
||||
get:
|
||||
tags:
|
||||
- users
|
||||
summary: Show all users
|
||||
description: Returns an array of all UserInfo
|
||||
operationId: listAllUsers
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
200:
|
||||
description: successful operation
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/definitions/UserInfo'
|
||||
500:
|
||||
description: internal error
|
||||
/admin/users/{UID}:
|
||||
get:
|
||||
tags:
|
||||
- users
|
||||
summary: Show userinfo by UID
|
||||
description: Returns a UserInfo object
|
||||
operationId: getUserInfo
|
||||
produces:
|
||||
- application/json
|
||||
parameters:
|
||||
- name: UID
|
||||
in: path
|
||||
description: UID of the user
|
||||
required: true
|
||||
type: string
|
||||
format: byte
|
||||
responses:
|
||||
200:
|
||||
description: successful operation
|
||||
schema:
|
||||
$ref: '#/definitions/UserInfo'
|
||||
400:
|
||||
description: bad request
|
||||
404:
|
||||
description: User not found
|
||||
500:
|
||||
description: internal error
|
||||
post:
|
||||
tags:
|
||||
- users
|
||||
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
|
||||
operationId: writeUserInfo
|
||||
consumes:
|
||||
- application/json
|
||||
produces:
|
||||
- application/json
|
||||
parameters:
|
||||
- name: UID
|
||||
in: path
|
||||
description: UID of the user
|
||||
required: true
|
||||
type: string
|
||||
format: byte
|
||||
- name: UserInfo
|
||||
in: body
|
||||
description: New userinfo
|
||||
required: true
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/definitions/UserInfo'
|
||||
responses:
|
||||
201:
|
||||
description: successful operation
|
||||
400:
|
||||
description: bad request
|
||||
500:
|
||||
description: internal error
|
||||
delete:
|
||||
tags:
|
||||
- users
|
||||
summary: Deletes a user
|
||||
operationId: deleteUser
|
||||
produces:
|
||||
- application/json
|
||||
parameters:
|
||||
- name: UID
|
||||
in: path
|
||||
description: UID of the user to be deleted
|
||||
required: true
|
||||
type: string
|
||||
format: byte
|
||||
responses:
|
||||
200:
|
||||
description: successful operation
|
||||
400:
|
||||
description: bad request
|
||||
404:
|
||||
description: User not found
|
||||
500:
|
||||
description: internal error
|
||||
|
||||
definitions:
|
||||
UserInfo:
|
||||
type: object
|
||||
properties:
|
||||
UID:
|
||||
type: string
|
||||
format: byte
|
||||
SessionsCap:
|
||||
type: integer
|
||||
format: int32
|
||||
UpRate:
|
||||
type: integer
|
||||
format: int64
|
||||
DownRate:
|
||||
type: integer
|
||||
format: int64
|
||||
UpCredit:
|
||||
type: integer
|
||||
format: int64
|
||||
DownCredit:
|
||||
type: integer
|
||||
format: int64
|
||||
ExpiryTime:
|
||||
type: integer
|
||||
format: int64
|
||||
externalDocs:
|
||||
description: Find out more about Swagger
|
||||
url: http://swagger.io
|
||||
# Added by API Auto Mocking Plugin
|
||||
host: 127.0.0.1:8080
|
||||
basePath: /
|
||||
schemes:
|
||||
- http
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
gmux "github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
type APIRouter struct {
|
||||
*gmux.Router
|
||||
manager UserManager
|
||||
}
|
||||
|
||||
func APIRouterOf(manager UserManager) *APIRouter {
|
||||
ret := &APIRouter{
|
||||
manager: manager,
|
||||
}
|
||||
ret.registerMux()
|
||||
return ret
|
||||
}
|
||||
|
||||
func corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (ar *APIRouter) registerMux() {
|
||||
ar.Router = gmux.NewRouter()
|
||||
ar.HandleFunc("/admin/users", ar.listAllUsersHlr).Methods("GET")
|
||||
ar.HandleFunc("/admin/users/{UID}", ar.getUserInfoHlr).Methods("GET")
|
||||
ar.HandleFunc("/admin/users/{UID}", ar.writeUserInfoHlr).Methods("POST")
|
||||
ar.HandleFunc("/admin/users/{UID}", ar.deleteUserHlr).Methods("DELETE")
|
||||
ar.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
|
||||
})
|
||||
ar.Use(corsMiddleware)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) listAllUsersHlr(w http.ResponseWriter, r *http.Request) {
|
||||
infos, err := ar.manager.ListAllUsers()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
resp, err := json.Marshal(infos)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(resp)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) getUserInfoHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
uinfo, err := ar.manager.GetUserInfo(UID)
|
||||
if err == ErrUserNotFound {
|
||||
http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
resp, err := json.Marshal(uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(resp)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var uinfo UserInfo
|
||||
err = json.NewDecoder(r.Body).Decode(&uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(UID, uinfo.UID) {
|
||||
http.Error(w, "UID mismatch", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = ar.manager.WriteUserInfo(uinfo)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (ar *APIRouter) deleteUserHlr(w http.ResponseWriter, r *http.Request) {
|
||||
b64UID := gmux.Vars(r)["UID"]
|
||||
if b64UID == "" {
|
||||
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
UID, err := base64.URLEncoding.DecodeString(b64UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = ar.manager.DeleteUser(UID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
|
@ -1,214 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mockUIDb64 = base64.URLEncoding.EncodeToString(mockUID)
|
||||
|
||||
func makeRouter(t *testing.T) (router *APIRouter, cleaner func()) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
cleaner = func() { os.Remove(tmpDB.Name()) }
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
router = APIRouterOf(mgr)
|
||||
return router, cleaner
|
||||
}
|
||||
|
||||
func TestWriteUserInfoHlr(t *testing.T) {
|
||||
router, cleaner := makeRouter(t)
|
||||
defer cleaner()
|
||||
|
||||
marshalled, err := json.Marshal(mockUserInfo)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
|
||||
})
|
||||
|
||||
t.Run("partial update", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
|
||||
assert.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusCreated, rr.Code)
|
||||
|
||||
partialUserInfo := UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: JustInt32(10),
|
||||
}
|
||||
partialMarshalled, _ := json.Marshal(partialUserInfo)
|
||||
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
|
||||
assert.NoError(t, err)
|
||||
router.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusCreated, rr.Code)
|
||||
|
||||
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
|
||||
assert.NoError(t, err)
|
||||
router.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusCreated, rr.Code)
|
||||
var got UserInfo
|
||||
err = json.Unmarshal(rr.Body.Bytes(), &got)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := mockUserInfo
|
||||
expected.SessionsCap = partialUserInfo.SessionsCap
|
||||
assert.EqualValues(t, expected, got)
|
||||
})
|
||||
|
||||
t.Run("empty parameter", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equalf(t, http.StatusMethodNotAllowed, rr.Code, "response body: %v", rr.Body)
|
||||
})
|
||||
|
||||
t.Run("UID mismatch", func(t *testing.T) {
|
||||
badMock := mockUserInfo
|
||||
badMock.UID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0}
|
||||
badMarshal, err := json.Marshal(badMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(badMarshal))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
|
||||
})
|
||||
|
||||
t.Run("garbage data", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer([]byte(`{"{{'{;;}}}1`)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
|
||||
})
|
||||
|
||||
t.Run("not base64", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/admin/users/"+"defonotbase64", bytes.NewBuffer(marshalled))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
|
||||
})
|
||||
}
|
||||
|
||||
func addUser(t *testing.T, router *APIRouter, user UserInfo) {
|
||||
marshalled, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", "/admin/users/"+base64.URLEncoding.EncodeToString(user.UID), bytes.NewBuffer(marshalled))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
|
||||
}
|
||||
|
||||
func TestGetUserInfoHlr(t *testing.T) {
|
||||
router, cleaner := makeRouter(t)
|
||||
defer cleaner()
|
||||
|
||||
t.Run("empty parameter", func(t *testing.T) {
|
||||
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/", nil)
|
||||
})
|
||||
|
||||
t.Run("non-existent", func(t *testing.T) {
|
||||
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+base64.URLEncoding.EncodeToString([]byte("adsf")), nil)
|
||||
})
|
||||
|
||||
t.Run("not base64", func(t *testing.T) {
|
||||
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+"defonotbase64", nil)
|
||||
})
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
addUser(t, router, mockUserInfo)
|
||||
|
||||
var got UserInfo
|
||||
err := json.Unmarshal([]byte(assert.HTTPBody(router.ServeHTTP, "GET", "/admin/users/"+mockUIDb64, nil)), &got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.EqualValues(t, mockUserInfo, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteUserHlr(t *testing.T) {
|
||||
router, cleaner := makeRouter(t)
|
||||
defer cleaner()
|
||||
|
||||
t.Run("non-existent", func(t *testing.T) {
|
||||
assert.HTTPError(t, router.ServeHTTP, "DELETE", "/admin/users/"+base64.URLEncoding.EncodeToString([]byte("adsf")), nil)
|
||||
})
|
||||
|
||||
t.Run("not base64", func(t *testing.T) {
|
||||
assert.HTTPError(t, router.ServeHTTP, "DELETE", "/admin/users/"+"defonotbase64", nil)
|
||||
})
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
addUser(t, router, mockUserInfo)
|
||||
assert.HTTPSuccess(t, router.ServeHTTP, "DELETE", "/admin/users/"+mockUIDb64, nil)
|
||||
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+mockUIDb64, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListAllUsersHlr(t *testing.T) {
|
||||
router, cleaner := makeRouter(t)
|
||||
defer cleaner()
|
||||
|
||||
user1 := mockUserInfo
|
||||
addUser(t, router, user1)
|
||||
|
||||
user2 := mockUserInfo
|
||||
user2.UID = []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
|
||||
addUser(t, router, user2)
|
||||
|
||||
expected := []UserInfo{user1, user2}
|
||||
|
||||
var got []UserInfo
|
||||
err := json.Unmarshal([]byte(assert.HTTPBody(router.ServeHTTP, "GET", "/admin/users", nil)), &got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, assert.Subset(t, got, expected), assert.Subset(t, expected, got))
|
||||
}
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
)
|
||||
|
||||
// FIXME: sanity checks. The server may panic due to user input
|
||||
|
||||
// TODO: manual backup
|
||||
|
||||
/*
|
||||
0 reserved
|
||||
1 listActiveUsers none []uids
|
||||
2 listAllUsers none []userinfo
|
||||
3 getUserInfo uid userinfo
|
||||
|
||||
4 addNewUser userinfo ok
|
||||
5 delUser uid ok
|
||||
6 syncMemFromDB uid ok
|
||||
|
||||
7 setSessionsCap uid cap ok
|
||||
8 setUpRate uid rate ok
|
||||
9 setDownRate uid rate ok
|
||||
10 setUpCredit uid credit ok
|
||||
11 setDownCredit uid credit ok
|
||||
12 setExpiryTime uid time ok
|
||||
13 addUpCredit uid delta ok
|
||||
14 addDownCredit uid delta ok
|
||||
*/
|
||||
|
||||
type controller struct {
|
||||
*Userpanel
|
||||
adminUID []byte
|
||||
}
|
||||
|
||||
func (up *Userpanel) MakeController(adminUID []byte) *controller {
|
||||
return &controller{up, adminUID}
|
||||
}
|
||||
|
||||
var errInvalidArgument = errors.New("Invalid argument format")
|
||||
|
||||
func (c *controller) HandleRequest(req []byte) (resp []byte, err error) {
|
||||
check := func(err error) []byte {
|
||||
if err != nil {
|
||||
return c.respond([]byte(err.Error()))
|
||||
} else {
|
||||
return c.respond([]byte("ok"))
|
||||
}
|
||||
}
|
||||
plain, err := c.checkAndDecrypt(req)
|
||||
if err == ErrInvalidMac {
|
||||
log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\naUID:%x\nraw request:\n%x\ndecrypted msg:\n%x", c.adminUID, req, plain)
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Printf("aUID:%x\n,err:%v\n", c.adminUID, err)
|
||||
return c.respond([]byte(err.Error())), nil
|
||||
}
|
||||
|
||||
typ := plain[0]
|
||||
var arg []byte
|
||||
if len(plain) > 1 {
|
||||
arg = plain[1:]
|
||||
}
|
||||
switch typ {
|
||||
case 1:
|
||||
UIDs := c.listActiveUsers()
|
||||
resp, _ = json.Marshal(UIDs)
|
||||
resp = c.respond(resp)
|
||||
case 2:
|
||||
uinfos := c.listAllUsers()
|
||||
resp, _ = json.Marshal(uinfos)
|
||||
resp = c.respond(resp)
|
||||
case 3:
|
||||
uinfo, err := c.getUserInfo(arg)
|
||||
if err != nil {
|
||||
resp = c.respond([]byte(err.Error()))
|
||||
break
|
||||
}
|
||||
resp, _ = json.Marshal(uinfo)
|
||||
resp = c.respond(resp)
|
||||
case 4:
|
||||
var uinfo UserInfo
|
||||
err = json.Unmarshal(arg, &uinfo)
|
||||
if err != nil {
|
||||
resp = c.respond([]byte(err.Error()))
|
||||
break
|
||||
}
|
||||
|
||||
err = c.addNewUser(uinfo)
|
||||
resp = check(err)
|
||||
case 5:
|
||||
err = c.delUser(arg)
|
||||
resp = check(err)
|
||||
case 6:
|
||||
err = c.syncMemFromDB(arg)
|
||||
resp = check(err)
|
||||
case 7:
|
||||
if len(arg) < 36 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setSessionsCap(arg[0:32], Uint32(arg[32:36]))
|
||||
resp = check(err)
|
||||
case 8:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setUpRate(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
case 9:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setDownRate(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
case 10:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setUpCredit(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
case 11:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setDownCredit(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
case 12:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setExpiryTime(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
case 13:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.addUpCredit(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
case 14:
|
||||
if len(arg) < 40 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.addDownCredit(arg[0:32], int64(Uint64(arg[32:40])))
|
||||
resp = check(err)
|
||||
default:
|
||||
return c.respond([]byte("Unsupported action")), nil
|
||||
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
var ErrInvalidMac = errors.New("Mac mismatch")
|
||||
var errMsgTooShort = errors.New("Message length is less than 54")
|
||||
|
||||
// protocol: [TLS record layer 5 bytes][IV 16 bytes][data][hmac 32 bytes]
|
||||
func (c *controller) respond(resp []byte) []byte {
|
||||
respLen := len(resp)
|
||||
|
||||
buf := make([]byte, 5+16+respLen+32)
|
||||
buf[0] = 0x17
|
||||
buf[1] = 0x03
|
||||
buf[2] = 0x03
|
||||
PutUint16(buf[3:5], uint16(16+respLen+32))
|
||||
|
||||
rand.Read(buf[5:21]) //iv
|
||||
copy(buf[21:], resp)
|
||||
block, _ := aes.NewCipher(c.adminUID[0:16])
|
||||
stream := cipher.NewCTR(block, buf[5:21])
|
||||
stream.XORKeyStream(buf[21:21+respLen], buf[21:21+respLen])
|
||||
|
||||
mac := hmac.New(sha256.New, c.adminUID[16:32])
|
||||
mac.Write(buf[5 : 21+respLen])
|
||||
copy(buf[21+respLen:], mac.Sum(nil))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *controller) checkAndDecrypt(data []byte) ([]byte, error) {
|
||||
if len(data) < 54 {
|
||||
return nil, errMsgTooShort
|
||||
}
|
||||
macIndex := len(data) - 32
|
||||
mac := hmac.New(sha256.New, c.adminUID[16:32])
|
||||
mac.Write(data[5:macIndex])
|
||||
expected := mac.Sum(nil)
|
||||
if !hmac.Equal(data[macIndex:], expected) {
|
||||
return nil, ErrInvalidMac
|
||||
}
|
||||
|
||||
iv := data[5:21]
|
||||
ret := data[21:macIndex]
|
||||
block, _ := aes.NewCipher(c.adminUID[0:16])
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(ret, ret)
|
||||
return ret, nil
|
||||
}
|
||||
|
|
@ -1,269 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
var u32 = binary.BigEndian.Uint32
|
||||
var u64 = binary.BigEndian.Uint64
|
||||
|
||||
func i64ToB(value int64) []byte {
|
||||
oct := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(oct, uint64(value))
|
||||
return oct
|
||||
}
|
||||
func i32ToB(value int32) []byte {
|
||||
nib := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(nib, uint32(value))
|
||||
return nib
|
||||
}
|
||||
|
||||
// localManager is responsible for managing the local user database
|
||||
type localManager struct {
|
||||
db *bolt.DB
|
||||
world common.WorldState
|
||||
}
|
||||
|
||||
func MakeLocalManager(dbPath string, worldState common.WorldState) (*localManager, error) {
|
||||
db, err := bolt.Open(dbPath, 0600, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := &localManager{
|
||||
db: db,
|
||||
world: worldState,
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Authenticate user returns err==nil along with the users' up and down bandwidths if the UID is allowed to connect
|
||||
// More specifically it checks that the user exists, that it has positive credit and that it hasn't expired
|
||||
func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) {
|
||||
var upRate, downRate, upCredit, downCredit, expiryTime int64
|
||||
err := manager.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(UID)
|
||||
if bucket == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
upRate = int64(u64(bucket.Get([]byte("UpRate"))))
|
||||
downRate = int64(u64(bucket.Get([]byte("DownRate"))))
|
||||
upCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
|
||||
downCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
|
||||
expiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if upCredit <= 0 {
|
||||
return 0, 0, ErrNoUpCredit
|
||||
}
|
||||
if downCredit <= 0 {
|
||||
return 0, 0, ErrNoDownCredit
|
||||
}
|
||||
if expiryTime < manager.world.Now().Unix() {
|
||||
return 0, 0, ErrUserExpired
|
||||
}
|
||||
|
||||
return upRate, downRate, nil
|
||||
}
|
||||
|
||||
// AuthoriseNewSession returns err==nil when the user is allowed to make a new session
|
||||
// More specifically it checks that the user exists, has credit, hasn't expired and hasn't reached sessionsCap
|
||||
func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo AuthorisationInfo) error {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
var sessionsCap int
|
||||
var upCredit, downCredit, expiryTime int64
|
||||
err := manager.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(arrUID[:])
|
||||
if bucket == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
sessionsCap = int(u32(bucket.Get([]byte("SessionsCap"))))
|
||||
upCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
|
||||
downCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
|
||||
expiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if upCredit <= 0 {
|
||||
return ErrNoUpCredit
|
||||
}
|
||||
if downCredit <= 0 {
|
||||
return ErrNoDownCredit
|
||||
}
|
||||
if expiryTime < manager.world.Now().Unix() {
|
||||
return ErrUserExpired
|
||||
}
|
||||
|
||||
if ainfo.NumExistingSessions >= sessionsCap {
|
||||
return ErrSessionsCapReached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UploadStatus gets StatusUpdates representing the recent status of each user, and update them in the database
|
||||
// it returns a slice of StatusResponse, which represents actions need to be taken for specific users.
|
||||
// If no action is needed, there won't be a StatusResponse entry for that user
|
||||
func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusResponse, error) {
|
||||
var responses []StatusResponse
|
||||
if len(uploads) == 0 {
|
||||
return responses, nil
|
||||
}
|
||||
err := manager.db.Update(func(tx *bolt.Tx) error {
|
||||
for _, status := range uploads {
|
||||
var resp StatusResponse
|
||||
bucket := tx.Bucket(status.UID)
|
||||
if bucket == nil {
|
||||
resp = StatusResponse{
|
||||
status.UID,
|
||||
TERMINATE,
|
||||
"User no longer exists",
|
||||
}
|
||||
responses = append(responses, resp)
|
||||
continue
|
||||
}
|
||||
|
||||
oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
|
||||
newUp := oldUp - status.UpUsage
|
||||
if newUp <= 0 {
|
||||
resp = StatusResponse{
|
||||
status.UID,
|
||||
TERMINATE,
|
||||
"No upload credit left",
|
||||
}
|
||||
responses = append(responses, resp)
|
||||
}
|
||||
err := bucket.Put([]byte("UpCredit"), i64ToB(newUp))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
oldDown := int64(u64(bucket.Get([]byte("DownCredit"))))
|
||||
newDown := oldDown - status.DownUsage
|
||||
if newDown <= 0 {
|
||||
resp = StatusResponse{
|
||||
status.UID,
|
||||
TERMINATE,
|
||||
"No download credit left",
|
||||
}
|
||||
responses = append(responses, resp)
|
||||
}
|
||||
err = bucket.Put([]byte("DownCredit"), i64ToB(newDown))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
expiry := int64(u64(bucket.Get([]byte("ExpiryTime"))))
|
||||
if manager.world.Now().Unix() > expiry {
|
||||
resp = StatusResponse{
|
||||
status.UID,
|
||||
TERMINATE,
|
||||
"User has expired",
|
||||
}
|
||||
responses = append(responses, resp)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return responses, err
|
||||
}
|
||||
|
||||
func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
|
||||
err = manager.db.View(func(tx *bolt.Tx) error {
|
||||
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
|
||||
var uinfo UserInfo
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
|
||||
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
|
||||
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
|
||||
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
|
||||
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
|
||||
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
|
||||
infos = append(infos, uinfo)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
})
|
||||
if infos == nil {
|
||||
infos = []UserInfo{}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error) {
|
||||
err = manager.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(UID)
|
||||
if bucket == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
|
||||
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
|
||||
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
|
||||
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
|
||||
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
|
||||
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
|
||||
err = manager.db.Update(func(tx *bolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists(u.UID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.SessionsCap != nil {
|
||||
if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u.UpRate != nil {
|
||||
if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u.DownRate != nil {
|
||||
if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u.UpCredit != nil {
|
||||
if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u.DownCredit != nil {
|
||||
if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u.ExpiryTime != nil {
|
||||
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (manager *localManager) DeleteUser(UID []byte) (err error) {
|
||||
err = manager.db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.DeleteBucket(UID)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (manager *localManager) Close() error {
|
||||
return manager.db.Close()
|
||||
}
|
||||
|
|
@ -1,365 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
||||
var mockUserInfo = UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: JustInt32(10),
|
||||
UpRate: JustInt64(100),
|
||||
DownRate: JustInt64(1000),
|
||||
UpCredit: JustInt64(10000),
|
||||
DownCredit: JustInt64(100000),
|
||||
ExpiryTime: JustInt64(1000000),
|
||||
}
|
||||
|
||||
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
cleaner = func() { os.Remove(tmpDB.Name()) }
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return mgr, cleaner
|
||||
}
|
||||
|
||||
func TestLocalManager_WriteUserInfo(t *testing.T) {
|
||||
mgr, cleaner := makeManager(t)
|
||||
defer cleaner()
|
||||
|
||||
err := mgr.WriteUserInfo(mockUserInfo)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
got, err := mgr.GetUserInfo(mockUID)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, mockUserInfo, got)
|
||||
|
||||
/* Partial update */
|
||||
err = mgr.WriteUserInfo(UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := mockUserInfo
|
||||
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
|
||||
got, err = mgr.GetUserInfo(mockUID)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, expected, got)
|
||||
}
|
||||
|
||||
func TestLocalManager_GetUserInfo(t *testing.T) {
|
||||
mgr, cleaner := makeManager(t)
|
||||
defer cleaner()
|
||||
|
||||
t.Run("simple fetch", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||
gotInfo, err := mgr.GetUserInfo(mockUID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotInfo, mockUserInfo) {
|
||||
t.Errorf("got wrong user info: %v", gotInfo)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("update a field", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||
updatedUserInfo := mockUserInfo
|
||||
updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
|
||||
|
||||
err := mgr.WriteUserInfo(updatedUserInfo)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
gotInfo, err := mgr.GetUserInfo(mockUID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotInfo, updatedUserInfo) {
|
||||
t.Errorf("got wrong user info: %v", updatedUserInfo)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
_, err := mgr.GetUserInfo(make([]byte, 16))
|
||||
if err != ErrUserNotFound {
|
||||
t.Errorf("expecting error %v, got %v", ErrUserNotFound, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalManager_DeleteUser(t *testing.T) {
|
||||
mgr, cleaner := makeManager(t)
|
||||
defer cleaner()
|
||||
|
||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||
err := mgr.DeleteUser(mockUID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
_, err = mgr.GetUserInfo(mockUID)
|
||||
if err != ErrUserNotFound {
|
||||
t.Error("user not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
var validUserInfo = mockUserInfo
|
||||
|
||||
func TestLocalManager_AuthenticateUser(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("normal auth", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
upRate, downRate, err := mgr.AuthenticateUser(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
|
||||
t.Error("wrong up or down rate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
_, _, err := mgr.AuthenticateUser(make([]byte, 16))
|
||||
if err != ErrUserNotFound {
|
||||
t.Error("user found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired user", func(t *testing.T) {
|
||||
expiredUserInfo := validUserInfo
|
||||
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
|
||||
|
||||
_ = mgr.WriteUserInfo(expiredUserInfo)
|
||||
|
||||
_, _, err := mgr.AuthenticateUser(expiredUserInfo.UID)
|
||||
if err != ErrUserExpired {
|
||||
t.Error("user not expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no credit", func(t *testing.T) {
|
||||
creditlessUserInfo := validUserInfo
|
||||
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
|
||||
|
||||
_ = mgr.WriteUserInfo(creditlessUserInfo)
|
||||
|
||||
_, _, err := mgr.AuthenticateUser(creditlessUserInfo.UID)
|
||||
if err != ErrNoUpCredit && err != ErrNoDownCredit {
|
||||
t.Error("user not creditless")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalManager_AuthoriseNewSession(t *testing.T) {
|
||||
mgr, cleaner := makeManager(t)
|
||||
defer cleaner()
|
||||
|
||||
t.Run("normal auth", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
err := mgr.AuthoriseNewSession(make([]byte, 16), AuthorisationInfo{NumExistingSessions: 0})
|
||||
if err != ErrUserNotFound {
|
||||
t.Error("user found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired user", func(t *testing.T) {
|
||||
expiredUserInfo := validUserInfo
|
||||
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
|
||||
|
||||
_ = mgr.WriteUserInfo(expiredUserInfo)
|
||||
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
|
||||
if err != ErrUserExpired {
|
||||
t.Error("user not expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("too many sessions", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
|
||||
if err != ErrSessionsCapReached {
|
||||
t.Error("session cap not reached")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalManager_UploadStatus(t *testing.T) {
|
||||
mgr, cleaner := makeManager(t)
|
||||
defer cleaner()
|
||||
|
||||
t.Run("simple update", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
|
||||
update := StatusUpdate{
|
||||
UID: validUserInfo.UID,
|
||||
Active: true,
|
||||
NumSession: 1,
|
||||
UpUsage: 10,
|
||||
DownUsage: 100,
|
||||
Timestamp: mockWorldState.Now().Unix(),
|
||||
}
|
||||
|
||||
_, err := mgr.UploadStatus([]StatusUpdate{update})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updatedUserInfo, err := mgr.GetUserInfo(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
|
||||
t.Error("up usage incorrect")
|
||||
}
|
||||
if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
|
||||
t.Error("down usage incorrect")
|
||||
}
|
||||
})
|
||||
|
||||
badUpdates := []struct {
|
||||
name string
|
||||
user UserInfo
|
||||
update StatusUpdate
|
||||
}{
|
||||
{"out of up credit",
|
||||
validUserInfo,
|
||||
StatusUpdate{
|
||||
UID: validUserInfo.UID,
|
||||
Active: true,
|
||||
NumSession: 1,
|
||||
UpUsage: *validUserInfo.UpCredit + 100,
|
||||
DownUsage: 0,
|
||||
Timestamp: mockWorldState.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{"out of down credit",
|
||||
validUserInfo,
|
||||
StatusUpdate{
|
||||
UID: validUserInfo.UID,
|
||||
Active: true,
|
||||
NumSession: 1,
|
||||
UpUsage: 0,
|
||||
DownUsage: *validUserInfo.DownCredit + 100,
|
||||
Timestamp: mockWorldState.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{"expired",
|
||||
UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: JustInt32(10),
|
||||
UpRate: JustInt64(0),
|
||||
DownRate: JustInt64(0),
|
||||
UpCredit: JustInt64(0),
|
||||
DownCredit: JustInt64(0),
|
||||
ExpiryTime: JustInt64(-1),
|
||||
},
|
||||
StatusUpdate{
|
||||
UID: mockUserInfo.UID,
|
||||
Active: true,
|
||||
NumSession: 1,
|
||||
UpUsage: 0,
|
||||
DownUsage: 0,
|
||||
Timestamp: mockWorldState.Now().Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, badUpdate := range badUpdates {
|
||||
t.Run(badUpdate.name, func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(badUpdate.user)
|
||||
resps, err := mgr.UploadStatus([]StatusUpdate{badUpdate.update})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(resps) == 0 {
|
||||
t.Fatal("expecting responses")
|
||||
}
|
||||
|
||||
resp := resps[0]
|
||||
if resp.Action != TERMINATE {
|
||||
t.Errorf("didn't terminate when %v", badUpdate.name)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalManager_ListAllUsers(t *testing.T) {
|
||||
mgr, cleaner := makeManager(t)
|
||||
defer cleaner()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var users []UserInfo
|
||||
for i := 0; i < 100; i++ {
|
||||
randUID := make([]byte, 16)
|
||||
rand.Read(randUID)
|
||||
newUser := UserInfo{
|
||||
UID: randUID,
|
||||
SessionsCap: JustInt32(rand.Int31()),
|
||||
UpRate: JustInt64(rand.Int63()),
|
||||
DownRate: JustInt64(rand.Int63()),
|
||||
UpCredit: JustInt64(rand.Int63()),
|
||||
DownCredit: JustInt64(rand.Int63()),
|
||||
ExpiryTime: JustInt64(rand.Int63()),
|
||||
}
|
||||
users = append(users, newUser)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
err := mgr.WriteUserInfo(newUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
listedUsers, err := mgr.ListAllUsers()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return binary.BigEndian.Uint64(users[i].UID[0:8]) < binary.BigEndian.Uint64(users[j].UID[0:8])
|
||||
})
|
||||
sort.Slice(listedUsers, func(i, j int) bool {
|
||||
return binary.BigEndian.Uint64(listedUsers[i].UID[0:8]) < binary.BigEndian.Uint64(listedUsers[j].UID[0:8])
|
||||
})
|
||||
if !reflect.DeepEqual(users, listedUsers) {
|
||||
t.Error("listed users deviates from uploaded ones")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
)
|
||||
|
||||
// for the ease of using json package
|
||||
type UserInfo struct {
|
||||
UID []byte
|
||||
// ALL of the following fields have to be accessed atomically
|
||||
SessionsCap uint32
|
||||
UpRate int64
|
||||
DownRate int64
|
||||
UpCredit int64
|
||||
DownCredit int64
|
||||
ExpiryTime int64
|
||||
}
|
||||
|
||||
type User struct {
|
||||
up *Userpanel
|
||||
|
||||
arrUID [32]byte
|
||||
|
||||
*UserInfo
|
||||
|
||||
valve *mux.Valve
|
||||
|
||||
sessionsM sync.RWMutex
|
||||
sessions map[uint32]*mux.Session
|
||||
}
|
||||
|
||||
func MakeUser(up *Userpanel, uinfo *UserInfo) *User {
|
||||
// this instance of valve is shared across ALL sessions of a user
|
||||
valve := mux.MakeValve(uinfo.UpRate, uinfo.DownRate, &uinfo.UpCredit, &uinfo.DownCredit)
|
||||
u := &User{
|
||||
up: up,
|
||||
UserInfo: uinfo,
|
||||
valve: valve,
|
||||
sessions: make(map[uint32]*mux.Session),
|
||||
}
|
||||
copy(u.arrUID[:], uinfo.UID)
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *User) addUpCredit(delta int64) { u.valve.AddRxCredit(delta) }
|
||||
func (u *User) addDownCredit(delta int64) { u.valve.AddTxCredit(delta) }
|
||||
func (u *User) setSessionsCap(cap uint32) { atomic.StoreUint32(&u.SessionsCap, cap) }
|
||||
func (u *User) setUpRate(rate int64) { u.valve.SetRxRate(rate) }
|
||||
func (u *User) setDownRate(rate int64) { u.valve.SetTxRate(rate) }
|
||||
func (u *User) setUpCredit(n int64) { u.valve.SetRxCredit(n) }
|
||||
func (u *User) setDownCredit(n int64) { u.valve.SetTxCredit(n) }
|
||||
func (u *User) setExpiryTime(time int64) { atomic.StoreInt64(&u.ExpiryTime, time) }
|
||||
|
||||
func (u *User) updateInfo(uinfo UserInfo) {
|
||||
u.setSessionsCap(uinfo.SessionsCap)
|
||||
u.setUpCredit(uinfo.UpCredit)
|
||||
u.setDownCredit(uinfo.DownCredit)
|
||||
u.setUpRate(uinfo.UpRate)
|
||||
u.setDownRate(uinfo.DownRate)
|
||||
u.setExpiryTime(uinfo.ExpiryTime)
|
||||
}
|
||||
|
||||
func (u *User) DelSession(sessionID uint32) {
|
||||
u.sessionsM.Lock()
|
||||
delete(u.sessions, sessionID)
|
||||
if len(u.sessions) == 0 {
|
||||
u.sessionsM.Unlock()
|
||||
u.up.delActiveUser(u.UID)
|
||||
return
|
||||
}
|
||||
u.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (u *User) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
|
||||
if time.Now().Unix() > u.ExpiryTime {
|
||||
return nil, false, errors.New("Expiry time passed")
|
||||
}
|
||||
u.sessionsM.Lock()
|
||||
if sesh = u.sessions[sessionID]; sesh != nil {
|
||||
u.sessionsM.Unlock()
|
||||
return sesh, true, nil
|
||||
} else {
|
||||
if len(u.sessions) >= int(u.SessionsCap) {
|
||||
u.sessionsM.Unlock()
|
||||
return nil, false, errors.New("SessionsCap reached")
|
||||
}
|
||||
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
|
||||
u.sessions[sessionID] = sesh
|
||||
u.sessionsM.Unlock()
|
||||
return sesh, false, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type StatusUpdate struct {
|
||||
UID []byte
|
||||
Active bool
|
||||
NumSession int
|
||||
|
||||
UpUsage int64
|
||||
DownUsage int64
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
type MaybeInt32 *int32
|
||||
type MaybeInt64 *int64
|
||||
|
||||
type UserInfo struct {
|
||||
UID []byte
|
||||
SessionsCap MaybeInt32
|
||||
UpRate MaybeInt64
|
||||
DownRate MaybeInt64
|
||||
UpCredit MaybeInt64
|
||||
DownCredit MaybeInt64
|
||||
ExpiryTime MaybeInt64
|
||||
}
|
||||
|
||||
func JustInt32(v int32) MaybeInt32 { return &v }
|
||||
|
||||
func JustInt64(v int64) MaybeInt64 { return &v }
|
||||
|
||||
type StatusResponse struct {
|
||||
UID []byte
|
||||
Action int
|
||||
Message string
|
||||
}
|
||||
|
||||
type AuthorisationInfo struct {
|
||||
NumExistingSessions int
|
||||
}
|
||||
|
||||
const (
|
||||
TERMINATE = iota + 1
|
||||
)
|
||||
|
||||
var ErrUserNotFound = errors.New("UID does not correspond to a user")
|
||||
var ErrSessionsCapReached = errors.New("Sessions cap has reached")
|
||||
var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified")
|
||||
|
||||
var ErrNoUpCredit = errors.New("No upload credit left")
|
||||
var ErrNoDownCredit = errors.New("No download credit left")
|
||||
var ErrUserExpired = errors.New("User has expired")
|
||||
|
||||
type UserManager interface {
|
||||
AuthenticateUser([]byte) (int64, int64, error)
|
||||
AuthoriseNewSession([]byte, AuthorisationInfo) error
|
||||
UploadStatus([]StatusUpdate) ([]StatusResponse, error)
|
||||
ListAllUsers() ([]UserInfo, error)
|
||||
GetUserInfo(UID []byte) (UserInfo, error)
|
||||
WriteUserInfo(UserInfo) error
|
||||
DeleteUser(UID []byte) error
|
||||
}
|
||||
|
|
@ -0,0 +1,472 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
var Uint32 = binary.BigEndian.Uint32
|
||||
var Uint64 = binary.BigEndian.Uint64
|
||||
var PutUint16 = binary.BigEndian.PutUint16
|
||||
var PutUint32 = binary.BigEndian.PutUint32
|
||||
var PutUint64 = binary.BigEndian.PutUint64
|
||||
|
||||
type Userpanel struct {
|
||||
db *bolt.DB
|
||||
bakRoot string
|
||||
|
||||
activeUsersM sync.RWMutex
|
||||
activeUsers map[[32]byte]*User
|
||||
}
|
||||
|
||||
func MakeUserpanel(dbPath, bakRoot string) (*Userpanel, error) {
|
||||
db, err := bolt.Open(dbPath, 0600, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bakRoot == "" {
|
||||
os.Mkdir("db-backup", 0777)
|
||||
bakRoot = "db-backup"
|
||||
}
|
||||
bakRoot = path.Clean(bakRoot)
|
||||
up := &Userpanel{
|
||||
db: db,
|
||||
bakRoot: bakRoot,
|
||||
activeUsers: make(map[[32]byte]*User),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Second * 10)
|
||||
up.updateCredits()
|
||||
}
|
||||
}()
|
||||
return up, nil
|
||||
}
|
||||
|
||||
// credits of all users are updated together so that there is only 1 goroutine managing it
|
||||
func (up *Userpanel) updateCredits() {
|
||||
up.activeUsersM.RLock()
|
||||
for _, u := range up.activeUsers {
|
||||
up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(u.arrUID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if err := b.Put([]byte("UpCredit"), i64ToB(u.valve.GetRxCredit())); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Put([]byte("DownCredit"), i64ToB(u.valve.GetTxCredit())); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
})
|
||||
}
|
||||
up.activeUsersM.RUnlock()
|
||||
|
||||
}
|
||||
|
||||
func (up *Userpanel) backupDB(bakFileName string) error {
|
||||
bakPath := up.bakRoot + "/" + bakFileName
|
||||
_, err := os.Stat(bakPath)
|
||||
if err == nil {
|
||||
return errors.New("Attempting to overwrite a file during backup!")
|
||||
}
|
||||
var bak *os.File
|
||||
if os.IsNotExist(err) {
|
||||
bak, err = os.Create(bakPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = up.db.View(func(tx *bolt.Tx) error {
|
||||
_, err := tx.WriteTo(bak)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var ErrUserNotFound = errors.New("User does not exist in db")
|
||||
var ErrUserNotActive = errors.New("User is not active")
|
||||
|
||||
func (up *Userpanel) GetAndActivateAdminUser(AdminUID []byte) (*User, error) {
|
||||
up.activeUsersM.Lock()
|
||||
var arrUID [32]byte
|
||||
copy(arrUID[:], AdminUID)
|
||||
if user, ok := up.activeUsers[arrUID]; ok {
|
||||
up.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
uinfo := UserInfo{
|
||||
UID: AdminUID,
|
||||
SessionsCap: 1e9,
|
||||
UpRate: 1e12,
|
||||
DownRate: 1e12,
|
||||
UpCredit: 1e15,
|
||||
DownCredit: 1e15,
|
||||
ExpiryTime: 1e15,
|
||||
}
|
||||
|
||||
user := MakeUser(up, &uinfo)
|
||||
up.activeUsers[arrUID] = user
|
||||
up.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's info
|
||||
// from the db and mark it as an active user
|
||||
func (up *Userpanel) GetAndActivateUser(UID []byte) (*User, error) {
|
||||
up.activeUsersM.Lock()
|
||||
var arrUID [32]byte
|
||||
copy(arrUID[:], UID)
|
||||
if user, ok := up.activeUsers[arrUID]; ok {
|
||||
up.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
var uinfo UserInfo
|
||||
uinfo.UID = UID
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) // reee brackets
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
up.activeUsersM.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
u := MakeUser(up, &uinfo)
|
||||
up.activeUsers[arrUID] = u
|
||||
up.activeUsersM.Unlock()
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) updateDBEntryUint32(UID []byte, key string, value uint32) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if err := b.Put([]byte(key), u32ToB(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) updateDBEntryInt64(UID []byte, key string, value int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if err := b.Put([]byte(key), i64ToB(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// This is used when all sessions of a user close
|
||||
func (up *Userpanel) delActiveUser(UID []byte) {
|
||||
var arrUID [32]byte
|
||||
copy(arrUID[:], UID)
|
||||
up.activeUsersM.Lock()
|
||||
delete(up.activeUsers, arrUID)
|
||||
up.activeUsersM.Unlock()
|
||||
}
|
||||
|
||||
func (up *Userpanel) getActiveUser(UID []byte) *User {
|
||||
var arrUID [32]byte
|
||||
copy(arrUID[:], UID)
|
||||
up.activeUsersM.RLock()
|
||||
ret := up.activeUsers[arrUID]
|
||||
up.activeUsersM.RUnlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
// below are remote control utilised functions
|
||||
|
||||
func (up *Userpanel) listActiveUsers() [][]byte {
|
||||
var ret [][]byte
|
||||
up.activeUsersM.RLock()
|
||||
for _, u := range up.activeUsers {
|
||||
ret = append(ret, u.UID)
|
||||
}
|
||||
up.activeUsersM.RUnlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (up *Userpanel) listAllUsers() []UserInfo {
|
||||
var ret []UserInfo
|
||||
up.db.View(func(tx *bolt.Tx) error {
|
||||
tx.ForEach(func(UID []byte, b *bolt.Bucket) error {
|
||||
// if we want to avoid writing every single key out,
|
||||
// we would have to either make UserInfo a map,
|
||||
// or use reflect.
|
||||
// neither is convinient
|
||||
var uinfo UserInfo
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
ret = append(ret, uinfo)
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
func (up *Userpanel) getUserInfo(UID []byte) (UserInfo, error) {
|
||||
var uinfo UserInfo
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
return uinfo, err
|
||||
}
|
||||
|
||||
// In boltdb, the value argument for bucket.Put has to be valid for the duration
|
||||
// of the transaction.
|
||||
// This basically means that you cannot reuse a byte slice for two different keys
|
||||
// in a transaction. So we need to allocate a fresh byte slice for each value
|
||||
func u32ToB(value uint32) []byte {
|
||||
quad := make([]byte, 4)
|
||||
PutUint32(quad, value)
|
||||
return quad
|
||||
}
|
||||
|
||||
func i64ToB(value int64) []byte {
|
||||
oct := make([]byte, 8)
|
||||
PutUint64(oct, uint64(value))
|
||||
return oct
|
||||
}
|
||||
|
||||
func (up *Userpanel) addNewUser(uinfo UserInfo) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket(uinfo.UID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("SessionsCap"), u32ToB(uinfo.SessionsCap)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) delUser(UID []byte) error {
|
||||
err := up.backupDB(strconv.FormatInt(time.Now().Unix(), 10) + ".bak")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = up.db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.DeleteBucket(UID)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) syncMemFromDB(UID []byte) error {
|
||||
var uinfo UserInfo
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return ErrUserNotActive
|
||||
}
|
||||
u.updateInfo(uinfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// the following functions will update the db entries first, then if the
|
||||
// user is active, it will update it in memory.
|
||||
|
||||
func (up *Userpanel) setSessionsCap(UID []byte, cap uint32) error {
|
||||
err := up.updateDBEntryUint32(UID, "SessionsCap", cap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setSessionsCap(cap)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) setUpRate(UID []byte, rate int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "UpRate", rate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setUpRate(rate)
|
||||
return nil
|
||||
}
|
||||
func (up *Userpanel) setDownRate(UID []byte, rate int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "DownRate", rate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setDownRate(rate)
|
||||
return nil
|
||||
}
|
||||
func (up *Userpanel) setUpCredit(UID []byte, n int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "UpCredit", n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setUpCredit(n)
|
||||
return nil
|
||||
}
|
||||
func (up *Userpanel) setDownCredit(UID []byte, n int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "DownCredit", n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setDownCredit(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) setExpiryTime(UID []byte, time int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "ExpiryTime", time)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setExpiryTime(time)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) addUpCredit(UID []byte, delta int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
old := b.Get([]byte("UpCredit"))
|
||||
new := int64(Uint64(old)) + delta
|
||||
if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.addUpCredit(delta)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) addDownCredit(UID []byte, delta int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
old := b.Get([]byte("DownCredit"))
|
||||
new := int64(Uint64(old)) + delta
|
||||
if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.addDownCredit(delta)
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
type Voidmanager struct{}
|
||||
|
||||
func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) {
|
||||
return 0, 0, ErrMangerIsVoid
|
||||
}
|
||||
|
||||
func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error {
|
||||
return ErrMangerIsVoid
|
||||
}
|
||||
|
||||
func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) {
|
||||
return nil, ErrMangerIsVoid
|
||||
}
|
||||
|
||||
func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) {
|
||||
return []UserInfo{}, ErrMangerIsVoid
|
||||
}
|
||||
|
||||
func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) {
|
||||
return UserInfo{}, ErrMangerIsVoid
|
||||
}
|
||||
|
||||
func (v *Voidmanager) WriteUserInfo(info UserInfo) error {
|
||||
return ErrMangerIsVoid
|
||||
}
|
||||
|
||||
func (v *Voidmanager) DeleteUser(UID []byte) error {
|
||||
return ErrMangerIsVoid
|
||||
}
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var v = &Voidmanager{}
|
||||
|
||||
func Test_Voidmanager_AuthenticateUser(t *testing.T) {
|
||||
_, _, err := v.AuthenticateUser([]byte{})
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
||||
func Test_Voidmanager_AuthoriseNewSession(t *testing.T) {
|
||||
err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{})
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
||||
func Test_Voidmanager_DeleteUser(t *testing.T) {
|
||||
err := v.DeleteUser([]byte{})
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
||||
func Test_Voidmanager_GetUserInfo(t *testing.T) {
|
||||
_, err := v.GetUserInfo([]byte{})
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
||||
func Test_Voidmanager_ListAllUsers(t *testing.T) {
|
||||
_, err := v.ListAllUsers()
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
||||
func Test_Voidmanager_UploadStatus(t *testing.T) {
|
||||
_, err := v.UploadStatus([]StatusUpdate{})
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
||||
func Test_Voidmanager_WriteUserInfo(t *testing.T) {
|
||||
err := v.WriteUserInfo(UserInfo{})
|
||||
assert.Equal(t, ErrMangerIsVoid, err)
|
||||
}
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultUploadInterval = 1 * time.Minute
|
||||
|
||||
// userPanel is used to authenticate new users and book keep active users
|
||||
type userPanel struct {
|
||||
Manager usermanager.UserManager
|
||||
|
||||
activeUsersM sync.RWMutex
|
||||
activeUsers map[[16]byte]*ActiveUser
|
||||
usageUpdateQueueM sync.Mutex
|
||||
usageUpdateQueue map[[16]byte]*usagePair
|
||||
|
||||
uploadInterval time.Duration
|
||||
}
|
||||
|
||||
func MakeUserPanel(manager usermanager.UserManager) *userPanel {
|
||||
ret := &userPanel{
|
||||
Manager: manager,
|
||||
activeUsers: make(map[[16]byte]*ActiveUser),
|
||||
usageUpdateQueue: make(map[[16]byte]*usagePair),
|
||||
uploadInterval: defaultUploadInterval,
|
||||
}
|
||||
go ret.regularQueueUpload()
|
||||
return ret
|
||||
}
|
||||
|
||||
// GetBypassUser does the same as GetUser except it unconditionally creates an ActiveUser when the UID isn't already active
|
||||
func (panel *userPanel) GetBypassUser(UID []byte) (*ActiveUser, error) {
|
||||
panel.activeUsersM.Lock()
|
||||
defer panel.activeUsersM.Unlock()
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
if user, ok := panel.activeUsers[arrUID]; ok {
|
||||
return user, nil
|
||||
}
|
||||
user := &ActiveUser{
|
||||
panel: panel,
|
||||
valve: mux.UNLIMITED_VALVE,
|
||||
sessions: make(map[uint32]*mux.Session),
|
||||
bypass: true,
|
||||
}
|
||||
copy(user.arrUID[:], UID)
|
||||
panel.activeUsers[user.arrUID] = user
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUser retrieves the reference to an ActiveUser if it's already active, or creates a new ActiveUser of specified
|
||||
// UID with UserInfo queried from the UserManger, should the particular UID is allowed to connect
|
||||
func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) {
|
||||
panel.activeUsersM.Lock()
|
||||
defer panel.activeUsersM.Unlock()
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
if user, ok := panel.activeUsers[arrUID]; ok {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
upRate, downRate, err := panel.Manager.AuthenticateUser(UID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valve := mux.MakeValve(upRate, downRate)
|
||||
user := &ActiveUser{
|
||||
panel: panel,
|
||||
valve: valve,
|
||||
sessions: make(map[uint32]*mux.Session),
|
||||
}
|
||||
|
||||
copy(user.arrUID[:], UID)
|
||||
panel.activeUsers[user.arrUID] = user
|
||||
log.WithFields(log.Fields{
|
||||
"UID": base64.StdEncoding.EncodeToString(UID),
|
||||
}).Info("New active user")
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// TerminateActiveUser terminates a user and deletes its references
|
||||
func (panel *userPanel) TerminateActiveUser(user *ActiveUser, reason string) {
|
||||
log.WithFields(log.Fields{
|
||||
"UID": base64.StdEncoding.EncodeToString(user.arrUID[:]),
|
||||
"reason": reason,
|
||||
}).Info("Terminating active user")
|
||||
panel.updateUsageQueueForOne(user)
|
||||
user.closeAllSessions(reason)
|
||||
panel.activeUsersM.Lock()
|
||||
delete(panel.activeUsers, user.arrUID)
|
||||
panel.activeUsersM.Unlock()
|
||||
}
|
||||
|
||||
func (panel *userPanel) isActive(UID []byte) bool {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
panel.activeUsersM.RLock()
|
||||
_, ok := panel.activeUsers[arrUID]
|
||||
panel.activeUsersM.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
type usagePair struct {
|
||||
up *int64
|
||||
down *int64
|
||||
}
|
||||
|
||||
// updateUsageQueue zeroes the accumulated usage all ActiveUsers valve and put the usage data im usageUpdateQueue
|
||||
func (panel *userPanel) updateUsageQueue() {
|
||||
panel.activeUsersM.Lock()
|
||||
panel.usageUpdateQueueM.Lock()
|
||||
for _, user := range panel.activeUsers {
|
||||
if user.bypass {
|
||||
continue
|
||||
}
|
||||
|
||||
upIncured, downIncured := user.valve.Nullify()
|
||||
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
|
||||
atomic.AddInt64(usage.up, upIncured)
|
||||
atomic.AddInt64(usage.down, downIncured)
|
||||
} else {
|
||||
// if the user hasn't been added to the queue
|
||||
usage = &usagePair{&upIncured, &downIncured}
|
||||
panel.usageUpdateQueue[user.arrUID] = usage
|
||||
}
|
||||
}
|
||||
panel.activeUsersM.Unlock()
|
||||
panel.usageUpdateQueueM.Unlock()
|
||||
}
|
||||
|
||||
// updateUsageQueueForOne is the same as updateUsageQueue except it only updates one user's usage
|
||||
// this is useful when the user is being terminated
|
||||
func (panel *userPanel) updateUsageQueueForOne(user *ActiveUser) {
|
||||
// used when one particular user deactivates
|
||||
if user.bypass {
|
||||
return
|
||||
}
|
||||
upIncured, downIncured := user.valve.Nullify()
|
||||
panel.usageUpdateQueueM.Lock()
|
||||
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
|
||||
atomic.AddInt64(usage.up, upIncured)
|
||||
atomic.AddInt64(usage.down, downIncured)
|
||||
} else {
|
||||
usage = &usagePair{&upIncured, &downIncured}
|
||||
panel.usageUpdateQueue[user.arrUID] = usage
|
||||
}
|
||||
panel.usageUpdateQueueM.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// commitUpdate put all usageUpdates into a slice of StatusUpdate, calls Manager.UploadStatus, gets the responses
|
||||
// and act to each user according to the responses
|
||||
func (panel *userPanel) commitUpdate() error {
|
||||
panel.usageUpdateQueueM.Lock()
|
||||
statuses := make([]usermanager.StatusUpdate, 0, len(panel.usageUpdateQueue))
|
||||
for arrUID, usage := range panel.usageUpdateQueue {
|
||||
panel.activeUsersM.RLock()
|
||||
user := panel.activeUsers[arrUID]
|
||||
panel.activeUsersM.RUnlock()
|
||||
var numSession int
|
||||
if user != nil {
|
||||
if user.bypass {
|
||||
continue
|
||||
}
|
||||
numSession = user.NumSession()
|
||||
}
|
||||
status := usermanager.StatusUpdate{
|
||||
UID: arrUID[:],
|
||||
Active: panel.isActive(arrUID[:]),
|
||||
NumSession: numSession,
|
||||
UpUsage: *usage.up,
|
||||
DownUsage: *usage.down,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
statuses = append(statuses, status)
|
||||
}
|
||||
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
|
||||
panel.usageUpdateQueueM.Unlock()
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil
|
||||
}
|
||||
responses, err := panel.Manager.UploadStatus(statuses)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, resp := range responses {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], resp.UID)
|
||||
switch resp.Action {
|
||||
case usermanager.TERMINATE:
|
||||
panel.activeUsersM.RLock()
|
||||
user := panel.activeUsers[arrUID]
|
||||
panel.activeUsersM.RUnlock()
|
||||
if user != nil {
|
||||
panel.TerminateActiveUser(user, resp.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (panel *userPanel) regularQueueUpload() {
|
||||
for {
|
||||
time.Sleep(panel.uploadInterval)
|
||||
go func() {
|
||||
panel.updateUsageQueue()
|
||||
err := panel.commitUpdate()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
)
|
||||
|
||||
func TestUserPanel_BypassUser(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
|
||||
manager, err := usermanager.MakeLocalManager(tmpDB.Name(), common.RealWorldState)
|
||||
if err != nil {
|
||||
t.Error("failed to make local manager", err)
|
||||
}
|
||||
panel := MakeUserPanel(manager)
|
||||
UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==")
|
||||
user, _ := panel.GetBypassUser(UID)
|
||||
user.valve.AddRx(10)
|
||||
user.valve.AddTx(10)
|
||||
t.Run("isActive", func(t *testing.T) {
|
||||
a := panel.isActive(UID)
|
||||
if !a {
|
||||
t.Error("isActive returned ", a)
|
||||
}
|
||||
})
|
||||
t.Run("updateUsageQueue", func(t *testing.T) {
|
||||
panel.updateUsageQueue()
|
||||
if _, inQ := panel.usageUpdateQueue[user.arrUID]; inQ {
|
||||
t.Error("user in update queue")
|
||||
}
|
||||
})
|
||||
t.Run("updateUsageQueueForOne", func(t *testing.T) {
|
||||
panel.updateUsageQueueForOne(user)
|
||||
if _, inQ := panel.usageUpdateQueue[user.arrUID]; inQ {
|
||||
t.Error("user in update queue")
|
||||
}
|
||||
})
|
||||
t.Run("commitUpdate", func(t *testing.T) {
|
||||
err := panel.commitUpdate()
|
||||
if err != nil {
|
||||
t.Error("commit returned", err)
|
||||
}
|
||||
})
|
||||
t.Run("TerminateActiveUser", func(t *testing.T) {
|
||||
panel.TerminateActiveUser(user, "")
|
||||
if panel.isActive(user.arrUID[:]) {
|
||||
t.Error("user still active after deletion", err)
|
||||
}
|
||||
})
|
||||
t.Run("Repeated delete", func(t *testing.T) {
|
||||
panel.TerminateActiveUser(user, "")
|
||||
})
|
||||
err = manager.Close()
|
||||
if err != nil {
|
||||
t.Error("failed to close localmanager", err)
|
||||
}
|
||||
}
|
||||
|
||||
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
||||
var validUserInfo = usermanager.UserInfo{
|
||||
UID: mockUID,
|
||||
SessionsCap: usermanager.JustInt32(10),
|
||||
UpRate: usermanager.JustInt64(100),
|
||||
DownRate: usermanager.JustInt64(1000),
|
||||
UpCredit: usermanager.JustInt64(10000),
|
||||
DownCredit: usermanager.JustInt64(100000),
|
||||
ExpiryTime: usermanager.JustInt64(1000000),
|
||||
}
|
||||
|
||||
func TestUserPanel_GetUser(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := usermanager.MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
panel := MakeUserPanel(mgr)
|
||||
|
||||
t.Run("normal user", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
|
||||
activeUser, err := panel.GetUser(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
again, err := panel.GetUser(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Errorf("can't get existing user: %v", err)
|
||||
}
|
||||
|
||||
if activeUser != again {
|
||||
t.Error("got different references")
|
||||
}
|
||||
})
|
||||
t.Run("non existent user", func(t *testing.T) {
|
||||
_, err = panel.GetUser(make([]byte, 16))
|
||||
if err != usermanager.ErrUserNotFound {
|
||||
t.Errorf("expecting error %v, got %v", usermanager.ErrUserNotFound, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserPanel_UpdateUsageQueue(t *testing.T) {
|
||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||
defer os.Remove(tmpDB.Name())
|
||||
mgr, err := usermanager.MakeLocalManager(tmpDB.Name(), mockWorldState)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
panel := MakeUserPanel(mgr)
|
||||
|
||||
t.Run("normal update", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
|
||||
user, err := panel.GetUser(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
user.valve.AddTx(1)
|
||||
user.valve.AddRx(2)
|
||||
panel.updateUsageQueue()
|
||||
err = panel.commitUpdate()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if user.valve.GetRx() != 0 || user.valve.GetTx() != 0 {
|
||||
t.Error("rx and tx stats are not cleared")
|
||||
}
|
||||
|
||||
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
|
||||
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
|
||||
t.Error("down credit incorrect update")
|
||||
}
|
||||
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
|
||||
t.Error("up credit incorrect update")
|
||||
}
|
||||
|
||||
// another update
|
||||
user.valve.AddTx(3)
|
||||
user.valve.AddRx(4)
|
||||
panel.updateUsageQueue()
|
||||
err = panel.commitUpdate()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
|
||||
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
|
||||
t.Error("down credit incorrect update")
|
||||
}
|
||||
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
|
||||
t.Error("up credit incorrect update")
|
||||
}
|
||||
})
|
||||
t.Run("terminating update", func(t *testing.T) {
|
||||
_ = mgr.WriteUserInfo(validUserInfo)
|
||||
|
||||
user, err := panel.GetUser(validUserInfo.UID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
user.valve.AddTx(*validUserInfo.DownCredit + 100)
|
||||
panel.updateUsageQueue()
|
||||
err = panel.commitUpdate()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if panel.isActive(validUserInfo.UID) {
|
||||
t.Error("user not terminated")
|
||||
}
|
||||
|
||||
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
|
||||
if *updatedUinfo.DownCredit != -100 {
|
||||
t.Error("down credit not updated correctly after the user has been terminated")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||
)
|
||||
|
||||
type WebSocket struct{}
|
||||
|
||||
func (WebSocket) String() string { return "WebSocket" }
|
||||
|
||||
func (WebSocket) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) {
|
||||
var req *http.Request
|
||||
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket)))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
|
||||
return
|
||||
}
|
||||
var hiddenData []byte
|
||||
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
|
||||
|
||||
fragments, err = WebSocket{}.unmarshalHidden(hiddenData, privateKey)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal hidden data from WS into authFragments: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
respond = WebSocket{}.makeResponder(reqPacket, fragments.sharedSecret)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (WebSocket) makeResponder(reqPacket []byte, sharedSecret [32]byte) Responder {
|
||||
respond := func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error) {
|
||||
handler := newWsHandshakeHandler()
|
||||
|
||||
// For an explanation of the following 3 lines, see the comments in websocketAux.go
|
||||
http.Serve(newWsAcceptor(originalConn, reqPacket), handler)
|
||||
|
||||
<-handler.finished
|
||||
preparedConn = handler.conn
|
||||
nonce := make([]byte, 12)
|
||||
common.RandRead(randSource, nonce)
|
||||
|
||||
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
|
||||
encryptedKey, err := common.AESGCMEncrypt(nonce, sharedSecret[:], sessionKey[:]) // 32 + 16 = 48 bytes
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encrypt reply: %v", err)
|
||||
return
|
||||
}
|
||||
reply := append(nonce, encryptedKey...)
|
||||
_, err = preparedConn.Write(reply)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to write reply: %v", err)
|
||||
preparedConn.Close()
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
return respond
|
||||
}
|
||||
|
||||
var ErrBadGET = errors.New("non (or malformed) HTTP GET")
|
||||
|
||||
func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fragments authFragments, err error) {
|
||||
if len(hidden) < 96 {
|
||||
err = ErrBadGET
|
||||
return
|
||||
}
|
||||
|
||||
copy(fragments.randPubKey[:], hidden[0:32])
|
||||
ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:])
|
||||
if !ok {
|
||||
err = ErrInvalidPubKey
|
||||
return
|
||||
}
|
||||
|
||||
var sharedSecret []byte
|
||||
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
copy(fragments.sharedSecret[:], sharedSecret)
|
||||
|
||||
if len(hidden[32:]) != 64 {
|
||||
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:]))
|
||||
return
|
||||
}
|
||||
|
||||
copy(fragments.ciphertextWithTag[:], hidden[32:])
|
||||
return
|
||||
}
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
|
||||
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
|
||||
//
|
||||
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
|
||||
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
|
||||
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
|
||||
//
|
||||
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
|
||||
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
|
||||
// net/http package upon receiving a request from a Conn.
|
||||
//
|
||||
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
|
||||
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
|
||||
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
|
||||
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
|
||||
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
|
||||
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
|
||||
// function.
|
||||
//
|
||||
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
|
||||
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
|
||||
// accepted.
|
||||
//
|
||||
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
|
||||
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
|
||||
// Accept will return error (so that the caller won't call again)
|
||||
//
|
||||
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
|
||||
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
|
||||
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
|
||||
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
|
||||
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
|
||||
//
|
||||
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
|
||||
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
|
||||
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
|
||||
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
|
||||
// websocket.Conn
|
||||
//
|
||||
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
|
||||
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
|
||||
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
|
||||
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
|
||||
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
|
||||
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
|
||||
//
|
||||
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
|
||||
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
|
||||
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
|
||||
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
|
||||
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
|
||||
// execution will block until the reference to util.WebSocketConn is ready.
|
||||
|
||||
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
|
||||
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
|
||||
// fake a conn that returns the first packet on first read
|
||||
type firstBuffedConn struct {
|
||||
net.Conn
|
||||
firstRead bool
|
||||
firstPacket []byte
|
||||
}
|
||||
|
||||
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
|
||||
if !c.firstRead {
|
||||
c.firstRead = true
|
||||
copy(buf, c.firstPacket)
|
||||
n := len(c.firstPacket)
|
||||
c.firstPacket = []byte{}
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(buf)
|
||||
}
|
||||
|
||||
type wsOnceListener struct {
|
||||
done bool
|
||||
c *firstBuffedConn
|
||||
}
|
||||
|
||||
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
|
||||
// http.Server. This is an acceptor that accepts only one Conn
|
||||
func newWsAcceptor(conn net.Conn, first []byte) *wsOnceListener {
|
||||
f := make([]byte, len(first))
|
||||
copy(f, first)
|
||||
return &wsOnceListener{
|
||||
c: &firstBuffedConn{Conn: conn, firstPacket: f},
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wsOnceListener) Accept() (net.Conn, error) {
|
||||
if w.done {
|
||||
return nil, errors.New("already accepted")
|
||||
}
|
||||
w.done = true
|
||||
return w.c, nil
|
||||
}
|
||||
|
||||
func (w *wsOnceListener) Close() error {
|
||||
w.done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *wsOnceListener) Addr() net.Addr {
|
||||
return w.c.LocalAddr()
|
||||
}
|
||||
|
||||
type wsHandshakeHandler struct {
|
||||
conn net.Conn
|
||||
finished chan struct{}
|
||||
}
|
||||
|
||||
// the handler to turn a net.Conn into a websocket.Conn
|
||||
func newWsHandshakeHandler() *wsHandshakeHandler {
|
||||
return &wsHandshakeHandler{finished: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upgrade connection to ws: %v", err)
|
||||
return
|
||||
}
|
||||
ws.conn = &common.WebSocketConn{Conn: c}
|
||||
ws.finished <- struct{}{}
|
||||
}
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/cbeuw/connutil"
|
||||
)
|
||||
|
||||
func TestFirstBuffedConn_Read(t *testing.T) {
|
||||
mockConn, writingEnd := connutil.AsyncPipe()
|
||||
|
||||
expectedFirstPacket := []byte{1, 2, 3}
|
||||
firstBuffedConn := &firstBuffedConn{
|
||||
Conn: mockConn,
|
||||
firstRead: false,
|
||||
firstPacket: expectedFirstPacket,
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, err := firstBuffedConn.Read(buf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(expectedFirstPacket, buf[:n]) {
|
||||
t.Error("first read doesn't produce given packet")
|
||||
return
|
||||
}
|
||||
|
||||
expectedSecondPacket := []byte{4, 5, 6, 7}
|
||||
writingEnd.Write(expectedSecondPacket)
|
||||
n, err = firstBuffedConn.Read(buf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(expectedSecondPacket, buf[:n]) {
|
||||
t.Error("second read doesn't produce subsequently written packet")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestWsAcceptor(t *testing.T) {
|
||||
mockConn := connutil.Discard()
|
||||
expectedFirstPacket := []byte{1, 2, 3}
|
||||
|
||||
wsAcceptor := newWsAcceptor(mockConn, expectedFirstPacket)
|
||||
_, err := wsAcceptor.Accept()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = wsAcceptor.Accept()
|
||||
if err == nil {
|
||||
t.Error("accepting second time doesn't return error")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,575 +0,0 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/client"
|
||||
"github.com/cbeuw/Cloak/internal/common"
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"github.com/cbeuw/Cloak/internal/server"
|
||||
"github.com/cbeuw/connutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const numConns = 200 // -race option limits the number of goroutines to 8192
|
||||
|
||||
func serveTCPEcho(l net.Listener) {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
_, err := io.Copy(conn, conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func serveUDPEcho(listener *connutil.PipeListener) {
|
||||
for {
|
||||
conn, err := listener.ListenPacket("udp", "")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
const bufSize = 32 * 1024
|
||||
go func(conn net.PacketConn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, bufSize)
|
||||
for {
|
||||
r, _, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
w, err := conn.WriteTo(buf[:r], nil)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
if r != w {
|
||||
log.Error("written not eqal to read")
|
||||
return
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=")
|
||||
var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=")
|
||||
|
||||
var basicUDPConfig = client.RawConfig{
|
||||
ServerName: "www.example.com",
|
||||
ProxyMethod: "openvpn",
|
||||
EncryptionMethod: "plain",
|
||||
UID: bypassUID[:],
|
||||
PublicKey: publicKey,
|
||||
NumConn: 4,
|
||||
UDP: true,
|
||||
Transport: "direct",
|
||||
RemoteHost: "fake.com",
|
||||
RemotePort: "9999",
|
||||
LocalHost: "127.0.0.1",
|
||||
LocalPort: "9999",
|
||||
}
|
||||
|
||||
var basicTCPConfig = client.RawConfig{
|
||||
ServerName: "www.example.com",
|
||||
ProxyMethod: "shadowsocks",
|
||||
EncryptionMethod: "plain",
|
||||
UID: bypassUID[:],
|
||||
PublicKey: publicKey,
|
||||
NumConn: 4,
|
||||
UDP: false,
|
||||
Transport: "direct",
|
||||
RemoteHost: "fake.com",
|
||||
RemotePort: "9999",
|
||||
LocalHost: "127.0.0.1",
|
||||
LocalPort: "9999",
|
||||
BrowserSig: "firefox",
|
||||
}
|
||||
|
||||
var singleplexTCPConfig = client.RawConfig{
|
||||
ServerName: "www.example.com",
|
||||
ProxyMethod: "shadowsocks",
|
||||
EncryptionMethod: "plain",
|
||||
UID: bypassUID[:],
|
||||
PublicKey: publicKey,
|
||||
NumConn: 0,
|
||||
UDP: false,
|
||||
Transport: "direct",
|
||||
RemoteHost: "fake.com",
|
||||
RemotePort: "9999",
|
||||
LocalHost: "127.0.0.1",
|
||||
LocalPort: "9999",
|
||||
BrowserSig: "safari",
|
||||
}
|
||||
|
||||
func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) {
|
||||
lcl, rmt, auth, err := rawConfig.ProcessRawConfig(state)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return lcl, rmt, auth
|
||||
}
|
||||
|
||||
func basicServerState(ws common.WorldState) *server.State {
|
||||
var serverConfig = server.RawConfig{
|
||||
ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
|
||||
BindAddr: []string{"fake.com:9999"},
|
||||
BypassUID: [][]byte{bypassUID[:]},
|
||||
RedirAddr: "fake.com:9999",
|
||||
PrivateKey: privateKey,
|
||||
KeepAlive: 15,
|
||||
CncMode: false,
|
||||
}
|
||||
state, err := server.InitState(serverConfig, ws)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
type mockUDPDialer struct {
|
||||
addrCh chan *net.UDPAddr
|
||||
raddr *net.UDPAddr
|
||||
}
|
||||
|
||||
func (m *mockUDPDialer) Dial(network, address string) (net.Conn, error) {
|
||||
if m.raddr == nil {
|
||||
m.raddr = <-m.addrCh
|
||||
}
|
||||
return net.DialUDP("udp", nil, m.raddr)
|
||||
}
|
||||
|
||||
func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) {
|
||||
// redirecting web server
|
||||
// ^
|
||||
// |
|
||||
// |
|
||||
// redirFromCkServerL
|
||||
// |
|
||||
// |
|
||||
// proxy client ----proxyToCkClientD----> ck-client ------> ck-server ----proxyFromCkServerL----> proxy server
|
||||
// ^
|
||||
// |
|
||||
// |
|
||||
// netToCkServerD
|
||||
// |
|
||||
// |
|
||||
// whatever connection initiator (including a proper ck-client)
|
||||
|
||||
netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024)
|
||||
|
||||
clientSeshMaker := func() *mux.Session {
|
||||
ai := ai
|
||||
quad := make([]byte, 4)
|
||||
common.RandRead(ai.WorldState.Rand, quad)
|
||||
ai.SessionId = binary.BigEndian.Uint32(quad)
|
||||
return client.MakeSession(rcc, ai, netToCkServerD)
|
||||
}
|
||||
|
||||
var proxyToCkClientD common.Dialer
|
||||
if ai.Unordered {
|
||||
// We can only "dial" a single UDP connection as we can't send packets from different context
|
||||
// to a single UDP listener
|
||||
addrCh := make(chan *net.UDPAddr, 1)
|
||||
mDialer := &mockUDPDialer{
|
||||
addrCh: addrCh,
|
||||
}
|
||||
acceptor := func() (*net.UDPConn, error) {
|
||||
laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
addrCh <- conn.LocalAddr().(*net.UDPAddr)
|
||||
return conn, err
|
||||
}
|
||||
go client.RouteUDP(acceptor, lcc.Timeout, rcc.Singleplex, clientSeshMaker)
|
||||
proxyToCkClientD = mDialer
|
||||
} else {
|
||||
var proxyToCkClientL *connutil.PipeListener
|
||||
proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024)
|
||||
go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, clientSeshMaker)
|
||||
}
|
||||
|
||||
// set up server
|
||||
ckServerToProxyD, proxyFromCkServerL := connutil.DialerListener(10 * 1024)
|
||||
ckServerToWebD, redirFromCkServerL := connutil.DialerListener(10 * 1024)
|
||||
serverState.ProxyDialer = ckServerToProxyD
|
||||
serverState.RedirDialer = ckServerToWebD
|
||||
|
||||
go server.Serve(ckServerListener, serverState)
|
||||
|
||||
return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
|
||||
}
|
||||
|
||||
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, conn := range conns {
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
testData := make([]byte, msgLen)
|
||||
rand.Read(testData)
|
||||
|
||||
// we cannot call t.Fatalf in concurrent contexts
|
||||
n, err := conn.Write(testData)
|
||||
if n != msgLen {
|
||||
t.Errorf("written only %v, err %v", n, err)
|
||||
return
|
||||
}
|
||||
|
||||
recvBuf := make([]byte, msgLen)
|
||||
_, err = io.ReadFull(conn, recvBuf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read back: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(testData, recvBuf) {
|
||||
t.Errorf("echoed data not correct")
|
||||
return
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestUDP(t *testing.T) {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
|
||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
|
||||
sta := basicServerState(worldState)
|
||||
|
||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("simple send", func(t *testing.T) {
|
||||
pxyClientConn, err := proxyToCkClientD.Dial("udp", "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
const testDataLen = 1500
|
||||
testData := make([]byte, testDataLen)
|
||||
rand.Read(testData)
|
||||
n, err := pxyClientConn.Write(testData)
|
||||
if n != testDataLen {
|
||||
t.Errorf("wrong length sent: %v", n)
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
pxyServerConn, err := proxyFromCkServerL.ListenPacket("", "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
recvBuf := make([]byte, testDataLen+100)
|
||||
r, _, err := pxyServerConn.ReadFrom(recvBuf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !bytes.Equal(testData, recvBuf[:r]) {
|
||||
t.Error("read wrong data")
|
||||
}
|
||||
})
|
||||
|
||||
const echoMsgLen = 1024
|
||||
t.Run("user echo", func(t *testing.T) {
|
||||
go serveUDPEcho(proxyFromCkServerL)
|
||||
var conn [1]net.Conn
|
||||
conn[0], err = proxyToCkClientD.Dial("udp", "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
runEchoTest(t, conn[:], echoMsgLen)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestTCPSingleplex(t *testing.T) {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||
lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState)
|
||||
sta := basicServerState(worldState)
|
||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const echoMsgLen = 1 << 16
|
||||
go serveTCPEcho(proxyFromCkServerL)
|
||||
|
||||
proxyConn1, err := proxyToCkClientD.Dial("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
runEchoTest(t, []net.Conn{proxyConn1}, echoMsgLen)
|
||||
user, err := sta.Panel.GetUser(ai.UID[:])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch user: %v", err)
|
||||
}
|
||||
|
||||
if user.NumSession() != 1 {
|
||||
t.Error("no session were made on first connection establishment")
|
||||
}
|
||||
|
||||
proxyConn2, err := proxyToCkClientD.Dial("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
|
||||
if user.NumSession() != 2 {
|
||||
t.Error("no extra session were made on second connection establishment")
|
||||
}
|
||||
|
||||
// Both conns should work
|
||||
runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, echoMsgLen)
|
||||
|
||||
proxyConn1.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return user.NumSession() == 1
|
||||
}, time.Second, 10*time.Millisecond, "first session was not closed on connection close")
|
||||
|
||||
// conn2 should still work
|
||||
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
|
||||
|
||||
var conns [numConns]net.Conn
|
||||
for i := 0; i < numConns; i++ {
|
||||
conns[i], err = proxyToCkClientD.Dial("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
runEchoTest(t, conns[:], echoMsgLen)
|
||||
|
||||
}
|
||||
|
||||
func TestTCPMultiplex(t *testing.T) {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||
|
||||
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
||||
sta := basicServerState(worldState)
|
||||
|
||||
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("user echo single", func(t *testing.T) {
|
||||
for i := 0; i < 18; i += 2 {
|
||||
dataLen := 1 << i
|
||||
writeData := make([]byte, dataLen)
|
||||
rand.Read(writeData)
|
||||
t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) {
|
||||
go serveTCPEcho(proxyFromCkServerL)
|
||||
conn, err := proxyToCkClientD.Dial("", "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
n, err := conn.Write(writeData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n != dataLen {
|
||||
t.Errorf("write length doesn't match up: %v, expected %v", n, dataLen)
|
||||
}
|
||||
|
||||
recvBuf := make([]byte, dataLen)
|
||||
_, err = io.ReadFull(conn, recvBuf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !bytes.Equal(writeData, recvBuf) {
|
||||
t.Error("echoed data incorrect")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
const echoMsgLen = 16384
|
||||
t.Run("user echo", func(t *testing.T) {
|
||||
go serveTCPEcho(proxyFromCkServerL)
|
||||
var conns [numConns]net.Conn
|
||||
for i := 0; i < numConns; i++ {
|
||||
conns[i], err = proxyToCkClientD.Dial("", "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
runEchoTest(t, conns[:], echoMsgLen)
|
||||
})
|
||||
|
||||
t.Run("redir echo", func(t *testing.T) {
|
||||
go serveTCPEcho(redirFromCkServerL)
|
||||
var conns [numConns]net.Conn
|
||||
for i := 0; i < numConns; i++ {
|
||||
conns[i], err = netToCkServerD.Dial("", "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
runEchoTest(t, conns[:], echoMsgLen)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClosingStreamsFromProxy(t *testing.T) {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||
|
||||
for clientConfigName, clientConfig := range map[string]client.RawConfig{"basic": basicTCPConfig, "singleplex": singleplexTCPConfig} {
|
||||
clientConfig := clientConfig
|
||||
clientConfigName := clientConfigName
|
||||
t.Run(clientConfigName, func(t *testing.T) {
|
||||
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
|
||||
sta := basicServerState(worldState)
|
||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("closing from server", func(t *testing.T) {
|
||||
clientConn, _ := proxyToCkClientD.Dial("", "")
|
||||
clientConn.Write(make([]byte, 16))
|
||||
serverConn, _ := proxyFromCkServerL.Accept()
|
||||
serverConn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
_, err := clientConn.Read(make([]byte, 16))
|
||||
return err != nil
|
||||
}, time.Second, 10*time.Millisecond, "closing stream on server side is not reflected to the client")
|
||||
})
|
||||
|
||||
t.Run("closing from client", func(t *testing.T) {
|
||||
// closing stream on client side
|
||||
clientConn, _ := proxyToCkClientD.Dial("", "")
|
||||
clientConn.Write(make([]byte, 16))
|
||||
serverConn, _ := proxyFromCkServerL.Accept()
|
||||
clientConn.Close()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
_, err := serverConn.Read(make([]byte, 16))
|
||||
return err != nil
|
||||
}, time.Second, 10*time.Millisecond, "closing stream on client side is not reflected to the server")
|
||||
})
|
||||
|
||||
t.Run("send then close", func(t *testing.T) {
|
||||
testData := make([]byte, 24*1024)
|
||||
rand.Read(testData)
|
||||
clientConn, _ := proxyToCkClientD.Dial("", "")
|
||||
go func() {
|
||||
clientConn.Write(testData)
|
||||
// it takes time for this written data to be copied asynchronously
|
||||
// into ck-server's domain. If the pipe is closed before that, read
|
||||
// by ck-client in RouteTCP will fail as we have closed it.
|
||||
time.Sleep(700 * time.Millisecond)
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
readBuf := make([]byte, len(testData))
|
||||
serverConn, err := proxyFromCkServerL.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("failed to accept a connection delievering data sent before closing: %v", err)
|
||||
}
|
||||
_, err = io.ReadFull(serverConn, readBuf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read data sent before closing: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIntegration(b *testing.B) {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
||||
sta := basicServerState(worldState)
|
||||
const bufSize = 16 * 1024
|
||||
|
||||
encryptionMethods := map[string]byte{
|
||||
"plain": mux.EncryptionMethodPlain,
|
||||
"chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305,
|
||||
"aes-256-gcm": mux.EncryptionMethodAES256GCM,
|
||||
"aes-128-gcm": mux.EncryptionMethodAES128GCM,
|
||||
}
|
||||
|
||||
for name, method := range encryptionMethods {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
ai.EncryptionMethod = method
|
||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.Run("single stream bandwidth", func(b *testing.B) {
|
||||
more := make(chan int, 10)
|
||||
go func() {
|
||||
// sender
|
||||
writeBuf := make([]byte, bufSize+100)
|
||||
serverConn, _ := proxyFromCkServerL.Accept()
|
||||
for {
|
||||
serverConn.Write(writeBuf)
|
||||
<-more
|
||||
}
|
||||
}()
|
||||
// receiver
|
||||
clientConn, _ := proxyToCkClientD.Dial("", "")
|
||||
readBuf := make([]byte, bufSize)
|
||||
clientConn.Write([]byte{1}) // to make server accept
|
||||
b.SetBytes(bufSize)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
io.ReadFull(clientConn, readBuf)
|
||||
// ask for more
|
||||
more <- 0
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("single stream latency", func(b *testing.B) {
|
||||
clientConn, _ := proxyToCkClientD.Dial("", "")
|
||||
buf := []byte{1}
|
||||
clientConn.Write(buf)
|
||||
serverConn, _ := proxyFromCkServerL.Accept()
|
||||
serverConn.Read(buf)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
clientConn.Write(buf)
|
||||
serverConn.Read(buf)
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
package test
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
prand "math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte {
|
||||
block, _ := aes.NewCipher(key)
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(ciphertext, plaintext)
|
||||
return ciphertext
|
||||
}
|
||||
|
||||
func AESDecrypt(iv []byte, key []byte, ciphertext []byte) []byte {
|
||||
ret := make([]byte, len(ciphertext))
|
||||
copy(ret, ciphertext) // Because XORKeyStream is inplace, but we don't want the input to be changed
|
||||
block, _ := aes.NewCipher(key)
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(ret, ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
// PsudoRandBytes returns a byte slice filled with psudorandom bytes generated by the seed
|
||||
func PsudoRandBytes(length int, seed int64) []byte {
|
||||
r := prand.New(prand.NewSource(seed))
|
||||
ret := make([]byte, length)
|
||||
r.Read(ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
// ReadTLS reads TLS data according to its record layer
|
||||
func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) {
|
||||
// TCP is a stream. Multiple TLS messages can arrive at the same time,
|
||||
// a single message can also be segmented due to MTU of the IP layer.
|
||||
// This function guareentees a single TLS message to be read and everything
|
||||
// else is left in the buffer.
|
||||
i, err := io.ReadFull(conn, buffer[:5])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dataLength := int(binary.BigEndian.Uint16(buffer[3:5]))
|
||||
if dataLength > len(buffer) {
|
||||
err = errors.New("Reading TLS message: message size greater than buffer. message size: " + strconv.Itoa(dataLength))
|
||||
return
|
||||
}
|
||||
left := dataLength
|
||||
readPtr := 5
|
||||
|
||||
for left != 0 {
|
||||
// If left > buffer size (i.e. our message got segmented), the entire MTU is read
|
||||
// if left = buffer size, the entire buffer is all there left to read
|
||||
// if left < buffer size (i.e. multiple messages came together),
|
||||
// only the message we want is read
|
||||
i, err = io.ReadFull(conn, buffer[readPtr:readPtr+left])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
left -= i
|
||||
readPtr += i
|
||||
}
|
||||
|
||||
n = 5 + dataLength
|
||||
return
|
||||
}
|
||||
|
||||
// AddRecordLayer adds record layer to data
|
||||
func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
|
||||
length := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(length, uint16(len(input)))
|
||||
ret := make([]byte, 5+len(input))
|
||||
copy(ret[0:1], typ)
|
||||
copy(ret[1:3], ver)
|
||||
copy(ret[3:5], length)
|
||||
copy(ret[5:], input)
|
||||
return ret
|
||||
}
|
||||
|
||||
// PeelRecordLayer peels off the record layer
|
||||
func PeelRecordLayer(data []byte) []byte {
|
||||
ret := data[5:]
|
||||
return ret
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue