Compare commits

..

No commits in common. "master" and "v0.1.0" have entirely different histories.

99 changed files with 2878 additions and 10666 deletions

View File

@ -1,91 +0,0 @@
name: Build and test
on: [ push ]
jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ ubuntu-latest, macos-latest, windows-latest ]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '^1.24' # The Go version to download (if necessary) and use.
- run: go test -race -coverprofile coverage.txt -coverpkg ./... -covermode atomic ./...
- uses: codecov/codecov-action@v4
with:
files: coverage.txt
token: ${{ secrets.CODECOV_TOKEN }}
compat-test:
runs-on: ubuntu-latest
strategy:
matrix:
encryption-method: [ plain, chacha20-poly1305 ]
num-conn: [ 0, 1, 4 ]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '^1.24'
- name: Build Cloak
run: make
- name: Create configs
run: |
mkdir config
cat << EOF > config/ckclient.json
{
"Transport": "direct",
"ProxyMethod": "iperf",
"EncryptionMethod": "${{ matrix.encryption-method }}",
"UID": "Q4GAXHVgnDLXsdTpw6bmoQ==",
"PublicKey": "4dae/bF43FKGq+QbCc5P/E/MPM5qQeGIArjmJEHiZxc=",
"ServerName": "cloudflare.com",
"BrowserSig": "firefox",
"NumConn": ${{ matrix.num-conn }}
}
EOF
cat << EOF > config/ckserver.json
{
"ProxyBook": {
"iperf": [
"tcp",
"127.0.0.1:5201"
]
},
"BindAddr": [
":8443"
],
"BypassUID": [
"Q4GAXHVgnDLXsdTpw6bmoQ=="
],
"RedirAddr": "cloudflare.com",
"PrivateKey": "AAaskZJRPIAbiuaRLHsvZPvE6gzOeSjg+ZRg1ENau0Y="
}
EOF
- name: Start iperf3 server
run: docker run -d --name iperf-server --network host ajoergensen/iperf3:latest --server
- name: Test new client against old server
run: |
docker run -d --name old-cloak-server --network host -v $PWD/config:/go/Cloak/config cbeuw/cloak:latest build/ck-server -c config/ckserver.json --verbosity debug
build/ck-client -c config/ckclient.json -s 127.0.0.1 -p 8443 --verbosity debug | tee new-cloak-client.log &
docker run --network host ajoergensen/iperf3:latest --client 127.0.0.1 -p 1984
docker stop old-cloak-server
- name: Test old client against new server
run: |
build/ck-server -c config/ckserver.json --verbosity debug | tee new-cloak-server.log &
docker run -d --name old-cloak-client --network host -v $PWD/config:/go/Cloak/config cbeuw/cloak:latest build/ck-client -c config/ckclient.json -s 127.0.0.1 -p 8443 --verbosity debug
docker run --network host ajoergensen/iperf3:latest --client 127.0.0.1 -p 1984
docker stop old-cloak-client
- name: Dump docker logs
if: always()
run: |
docker container logs iperf-server > iperf-server.log
docker container logs old-cloak-server > old-cloak-server.log
docker container logs old-cloak-client > old-cloak-client.log
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.encryption-method }}-${{ matrix.num-conn }}-conn-logs
path: ./*.log

View File

@ -1,50 +0,0 @@
on:
push:
tags:
- 'v*'
name: Create Release
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build
run: |
export PATH=${PATH}:`go env GOPATH`/bin
v=${GITHUB_REF#refs/*/} ./release.sh
- name: Release
uses: softprops/action-gh-release@v1
with:
files: release/*
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
build-docker:
runs-on: ubuntu-latest
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: |
cbeuw/cloak
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v6
with:
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

6
.gitignore vendored
View File

@ -1,6 +0,0 @@
corpus/
suppressions/
crashers/
*.zip
.idea/
build/

View File

@ -1,5 +0,0 @@
FROM golang:latest
RUN git clone https://github.com/cbeuw/Cloak.git
WORKDIR Cloak
RUN make

674
LICENSE
View File

@ -1,674 +0,0 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

View File

@ -7,12 +7,10 @@ version=$(shell ver=$$(git log -n 1 --pretty=oneline --format=%D | awk -F, '{pri
echo $$ver)
client:
mkdir -p build
go build -ldflags "-X main.version=${version}" ./cmd/ck-client
mv ck-client* ./build
server:
mkdir -p build
go build -ldflags "-X main.version=${version}" ./cmd/ck-server
mv ck-server* ./build

259
README.md
View File

@ -1,242 +1,41 @@
[![Build Status](https://github.com/cbeuw/Cloak/workflows/Build%20and%20test/badge.svg)](https://github.com/cbeuw/Cloak/actions)
[![codecov](https://codecov.io/gh/cbeuw/Cloak/branch/master/graph/badge.svg)](https://codecov.io/gh/cbeuw/Cloak)
[![Go Report Card](https://goreportcard.com/badge/github.com/cbeuw/Cloak)](https://goreportcard.com/report/github.com/cbeuw/Cloak)
[![Donate](https://img.shields.io/badge/Donate-PayPal-green.svg)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url)
# Cloak
A shadowsocks plugin that obfuscates the traffic as normal HTTPS traffic and disguises the proxy server as a normal webserver.
<p align="center">
<img src="https://user-images.githubusercontent.com/7034308/96387206-3e214100-1198-11eb-8917-689d7c56e0cd.png" />
<img src="https://user-images.githubusercontent.com/7034308/155593583-f22bcfe2-ac22-4afb-9288-1a0e8a791a0d.svg" />
</p>
**This is an active WIP. Everything is subject to change.**
<p align="center">
<img src="https://user-images.githubusercontent.com/7034308/155629720-54dd8758-ec98-4fed-b603-623f0ad83b6c.svg" />
</p>
This project is based on [GoQuiet](https://github.com/cbeuw/GoQuiet). The most significant difference is that, in GoQuiet, a new TCP connection is establieshed and a TLS handshake is done between the client and the proxy server each time a connection is made to ssclient, whereas in Cloak all the traffic is multiplexed through a fixed amount of consistant TCP connections between the client and the proxy server. The major benefits are:
Cloak is a [pluggable transport](https://datatracker.ietf.org/meeting/103/materials/slides-103-pearg-pt-slides-01) that enhances
traditional proxy tools like OpenVPN to evade [sophisticated censorship](https://en.wikipedia.org/wiki/Deep_packet_inspection) and [data discrimination](https://en.wikipedia.org/wiki/Net_bias).
- Significantly quicker establishment of new connections as TLS handshake is only done on the startup of the client
Cloak is not a standalone proxy program. Rather, it works by masquerading proxied traffic as normal web browsing
activities. In contrast to traditional tools which have very prominent traffic fingerprints and can be blocked by simple filtering rules,
it's very difficult to precisely target Cloak with little false positives. This increases the collateral damage to censorship actions as
attempts to block Cloak could also damage services the censor state relies on.
- More realistic traffic pattern
To any third party observer, a host running Cloak server is indistinguishable from an innocent web server. Both while
passively observing traffic flow to and from the server, as well as while actively probing the behaviours of a Cloak
server. This is achieved through the use a series
of [cryptographic steganography techniques](https://github.com/cbeuw/Cloak/wiki/Steganography-and-encryption).
Cloak can be used in conjunction with any proxy program that tunnels traffic through TCP or
UDP, such as Shadowsocks, OpenVPN and Tor. Multiple proxy servers can be running on the same server host and
Cloak server will act as a reverse proxy, bridging clients with their desired proxy end.
Cloak multiplexes traffic through multiple underlying TCP connections which reduces head-of-line blocking and eliminates
TCP handshake overhead. This also makes the traffic pattern more similar to real websites.
Cloak provides multi-user support, allowing multiple clients to connect to the proxy server on the same port (443 by
default). It also provides traffic management features such as usage credit and bandwidth control. This allows a proxy
server to serve multiple users even if the underlying proxy software wasn't designed for multiple users
Cloak also supports tunneling through an intermediary CDN server such as Amazon Cloudfront. Such services are so widely used,
attempts to disrupt traffic to them can lead to very high collateral damage for the censor.
## Quick Start
To quickly deploy Cloak with Shadowsocks on a server, you can run
this [script](https://github.com/HirbodBehnam/Shadowsocks-Cloak-Installer/blob/master/Cloak2-Installer.sh) written by
@HirbodBehnam
Table of Contents
=================
* [Quick Start](#quick-start)
* [Build](#build)
* [Configuration](#configuration)
* [Server](#server)
* [Client](#client)
* [Setup](#setup)
* [Server](#server-1)
* [To add users](#to-add-users)
* [Unrestricted users](#unrestricted-users)
* [Users subject to bandwidth and credit controls](#users-subject-to-bandwidth-and-credit-controls)
* [Client](#client-1)
* [Support me](#support-me)
Besides, Cloak allows multiple users to use one server **on a single port**. QoS restrictions such as bandwidth limitation and data cap can also be managed.
## Build
```bash
git clone https://github.com/cbeuw/Cloak
cd Cloak
go get ./...
make
```
Built binaries will be in `build` folder.
## Configuration
Examples of configuration files can be found under `example_config` folder.
### Server
`RedirAddr` is the redirection address when the incoming traffic is not from a Cloak client. Ideally it should be set to
a major website allowed by the censor (e.g. `www.bing.com`)
`BindAddr` is a list of addresses Cloak will bind and listen to (e.g. `[":443",":80"]` to listen to port 443 and 80 on
all interfaces)
`ProxyBook` is an object whose key is the name of the ProxyMethod used on the client-side (case-sensitive). Its value is
an array whose first element is the protocol, and the second element is an `IP:PORT` string of the upstream proxy server
that Cloak will forward the traffic to.
Example:
```json
{
"ProxyBook": {
"shadowsocks": [
"tcp",
"localhost:51443"
],
"openvpn": [
"tcp",
"localhost:12345"
]
}
}
```
`PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64.
`BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions
`AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`.
`DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will
create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`.
This field also has no effect if `AdminUID` isn't a valid UID or is empty.
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
upstream proxy server. Zero or negative value disables it. Default is 0 (disabled).
### Client
`UID` is your UID in base64.
`Transport` can be either `direct` or `CDN`. If the server host wishes you to connect to it directly, use `direct`. If
instead a CDN is used, use `CDN`.
`PublicKey` is the static curve25519 public key in base64, given by the server admin.
`ProxyMethod` is the name of the proxy method you are using. This must match one of the entries in the
server's `ProxyBook` exactly.
`EncryptionMethod` is the name of the encryption algorithm you want Cloak to use. Options are `plain`, `aes-256-gcm` (
synonymous to `aes-gcm`), `aes-128-gcm`, and `chacha20-poly1305`. Note: Cloak isn't intended to provide transport
security. The point of encryption is to hide fingerprints of proxy protocols and render the payload statistically
random-like. **You may only leave it as `plain` if you are certain that your underlying proxy tool already provides BOTH
encryption and authentication (via AEAD or similar techniques).**
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. Use `random` to randomize the server name for every connection made.
`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new
connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects
the alternative domains.
Example:
```json
{
"ServerName": "bing.com",
"AlternativeNames": ["cloudflare.com", "github.com"]
}
```
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with
the CDN server, this domain name will be used in the `Host` header of the HTTP request to ask the CDN server to
establish a WebSocket connection with this host.
`CDNWsUrlPath` is the url path used to build websocket request sent under `CDN` mode, and also only has effect
when `Transport` is set to `CDN`. If unset, it will default to "/". This option is used to build the first line of the
HTTP request after a TLS session is extablished. It's mainly for a Cloak server behind a reverse proxy, while only
requests under specific url path are forwarded.
`NumConn` is the amount of underlying TCP connections you want to use. The default of 4 should be appropriate for most
people. Setting it too high will hinder the performance. Setting it to 0 will disable connection multiplexing and each
TCP connection will spawn a separate short-lived session that will be closed after it is terminated. This makes it
behave like GoQuiet. This maybe useful for people with unstable connections.
`BrowserSig` is the browser you want to **appear** to be using. It's not relevant to the browser you are actually using.
Currently, `chrome`, `firefox` and `safari` are supported.
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
Cloak server. Zero or negative value disables it. Default is 0 (disabled). Warning: Enabling it might make your server
more detectable as a proxy, but it will make the Cloak client detect internet interruption more quickly.
`StreamTimeout` is the number of seconds of Cloak waits for an incoming connection from a proxy program to send any
data, after which the connection will be closed by Cloak. Cloak will not enforce any timeout on TCP connections after it
is established.
Simply `make client` and `make server`. Output binary will be in the build folder
## Setup
### For the administrator of the server
0. [Install and configure shadowsocks-libev on your server](https://github.com/shadowsocks/shadowsocks-libev#installation)
1. Clone this repo onto your server
2. Build and run ck-server -k. The base64 string before the comma is the public key, the one after the comma is the private key
3. Run `ck-server -u`. This will be used as the AdminUID
4. Put the private key and the AdminUID you obtained previously into config/ckserver.json
5. Edit the configuration file of shadowsocks-libev (default location is /etc/shadowsocks-libev/config.json). Let `server_port` be `443`, `plugin` be the full path to the ck-server binary and `plugin_opts` be the full path to ckserver.json. If the fields `plugin` and `plugin_opts` were not present originally, add these fields to the config file.
6. Run ss-server as root (because we are binding to TCP port 443)
### Server
#### If you want to add more users
1. Run ck-server -u to generate a new UID
2. On your client, run `ck-client -a -c <path-to-ckclient.json>` to enter admin mode
3. Input as prompted, that is your ip:port of the server and your AdminUID. Enter 4 to create a new user.
4. Enter the UID in your ckclient.json as the prompted UID, enter SessionsCap (maximum amount of concurrent sessions a user can have), UpRate and DownRate (in bytes/s), UpCredit and DownCredit (in bytes) and ExpiryTime (as a unix epoch)
5. Give your PUBLIC key and the newly generated UID to the new user
0. Install at least one underlying proxy server (e.g. OpenVPN, Shadowsocks).
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo.
2. Run `ck-server -key`. The **public** should be given to users, the **private** key should be kept secret.
3. (Skip if you only want to add unrestricted users) Run `ck-server -uid`. The new UID will be used as `AdminUID`.
4. Copy example_config/ckserver.json into a desired location. Change `PrivateKey` to the private key you just obtained;
change `AdminUID` to the UID you just obtained.
5. Configure your underlying proxy server so that they all listen on localhost. Edit `ProxyBook` in the configuration
file accordingly
6. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides)
Run `sudo ck-server -c <path to ckserver.json>`. ck-server needs root privilege because it binds to a low numbered
port (443). Alternatively you can follow https://superuser.com/a/892391 to avoid granting ck-server root privilege
unnecessarily.
Note: the user database is persistent as it's in-disk. You don't need to add the users again each time you start ck-server.
#### To add users
##### Unrestricted users
Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.json`
##### Users subject to bandwidth and credit controls
0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db`
in `DatabasePath` (Cloak will create this file for you if it didn't already exist).
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to
enter admin mode
2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data
entered into this site are processed between your browser and the Cloak API endpoint you specified. Alternatively you
can download the repo at https://github.com/cbeuw/Cloak-panel and open `index.html` in a browser. No web server is
required).
3. Type in `127.0.0.1:<the port you entered in step 1>` as the API Base, and click `List`.
4. You can add in more users by clicking the `+` panel
Note: the user database is persistent as it's in-disk. You don't need to add the users again each time you start
ck-server.
### Client
**Android client is available here: https://github.com/cbeuw/Cloak-android**
0. Install the underlying proxy client corresponding to what the server has.
1. Download [the latest release](https://github.com/cbeuw/Cloak/releases) or clone and build this repo.
2. Obtain the public key and your UID from the administrator of your server
3. Copy `example_config/ckclient.json` into a location of your choice. Enter the `UID` and `PublicKey` you have
obtained. Set `ProxyMethod` to match exactly the corresponding entry in `ProxyBook` on the server end
4. [Configure the proxy program.](https://github.com/cbeuw/Cloak/wiki/Underlying-proxy-configuration-guides)
Run `ck-client -c <path to ckclient.json> -s <ip of your server>`
## Support me
If you find this project useful, you can visit my [merch store](https://www.redbubble.com/people/cbeuw/explore);
alternatively you can donate directly to me
[![Donate](https://img.shields.io/badge/Donate-PayPal-green.svg)](https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=SAUYKGSREP8GL&source=url)
BTC: `bc1q59yvpnh0356qq9vf0j2y7hx36t9ysap30spx9h`
ETH: `0x8effF29a8F9bD38A367580527AC303972c92b60c`
### Instructions for clients
0. Install and configure a version of shadowsocks client that supports plugins (such as shadowsocks-libev and shadowsocks-windows)
1. Clone this repo and build ck-client
2. Obtain the PUBLIC key and your UID (or the AdminUID, if you are the server admin) from the administrator of your server
3. Put the public key and the UID you obtained into config/ckclient.json
4. Configure your shadowsocks client with your server information. The field `plugin` should be the path to ck-server binary and `plugin_opts` should be the path to ckclient.json

276
cmd/ck-client/admin.go Normal file
View File

@ -0,0 +1,276 @@
//build !android
package main
// TODO: rewrite this. Think of another way of admin control
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/client/TLS"
"github.com/cbeuw/Cloak/internal/util"
)
type UserInfo struct {
UID []byte
// ALL of the following fields have to be accessed atomically
SessionsCap uint32
UpRate int64
DownRate int64
UpCredit int64
DownCredit int64
ExpiryTime int64
}
type administrator struct {
adminConn net.Conn
adminUID []byte
}
func adminPrompt(sta *client.State) error {
a, err := adminHandshake(sta)
if err != nil {
log.Println(err)
return err
}
fmt.Println(`1 listActiveUsers none []uids
2 listAllUsers none []userinfo
3 getUserInfo uid userinfo
4 addNewUser userinfo ok
5 delUser uid ok
6 syncMemFromDB uid ok
7 setSessionsCap uid cap ok
8 setUpRate uid rate ok
9 setDownRate uid rate ok
10 setUpCredit uid credit ok
11 setDownCredit uid credit ok
12 setExpiryTime uid time ok
13 addUpCredit uid delta ok
14 addDownCredit uid delta ok`)
buf := make([]byte, 16000)
for {
req, err := a.getRequest()
if err != nil {
log.Println(err)
continue
}
a.adminConn.Write(req)
n, err := a.adminConn.Read(buf)
if err != nil {
return err
}
resp, err := a.checkAndDecrypt(buf[:n])
if err != nil {
return err
}
fmt.Println(string(resp))
}
}
func adminHandshake(sta *client.State) (*administrator, error) {
fmt.Println("Enter the ip:port of your server")
var addr string
fmt.Scanln(&addr)
fmt.Println("Enter the admin UID")
var b64AdminUID string
fmt.Scanln(&b64AdminUID)
adminUID, err := base64.StdEncoding.DecodeString(b64AdminUID)
if err != nil {
return nil, err
}
sta.UID = adminUID
remoteConn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
clientHello := TLS.ComposeInitHandshake(sta)
_, err = remoteConn.Write(clientHello)
// Three discarded messages: ServerHello, ChangeCipherSpec and Finished
discardBuf := make([]byte, 1024)
for c := 0; c < 3; c++ {
_, err = util.ReadTLS(remoteConn, discardBuf)
if err != nil {
return nil, err
}
}
reply := TLS.ComposeReply()
_, err = remoteConn.Write(reply)
a := &administrator{remoteConn, adminUID}
return a, nil
}
func (a *administrator) getRequest() (req []byte, err error) {
promptUID := func() []byte {
fmt.Println("Enter UID")
var b64UID string
fmt.Scanln(&b64UID)
ret, _ := base64.StdEncoding.DecodeString(b64UID)
return ret
}
promptInt64 := func(name string) []byte {
fmt.Println("Enter New " + name)
var val int64
fmt.Scanln(&val)
ret := make([]byte, 8)
binary.BigEndian.PutUint64(ret, uint64(val))
return ret
}
promptUint32 := func(name string) []byte {
fmt.Println("Enter New " + name)
var val uint32
fmt.Scanln(&val)
ret := make([]byte, 4)
binary.BigEndian.PutUint32(ret, val)
return ret
}
fmt.Println("Select your command")
var cmd string
fmt.Scanln(&cmd)
switch cmd {
case "1":
req = a.request([]byte{0x01})
case "2":
req = a.request([]byte{0x02})
case "3":
UID := promptUID()
req = a.request(append([]byte{0x03}, UID...))
case "4":
var uinfo UserInfo
var b64UID string
fmt.Printf("UID:")
fmt.Scanln(&b64UID)
UID, _ := base64.StdEncoding.DecodeString(b64UID)
uinfo.UID = UID
fmt.Printf("SessionsCap:")
fmt.Scanf("%d", &uinfo.SessionsCap)
fmt.Printf("UpRate:")
fmt.Scanf("%d", &uinfo.UpRate)
fmt.Printf("DownRate:")
fmt.Scanf("%d", &uinfo.DownRate)
fmt.Printf("UpCredit:")
fmt.Scanf("%d", &uinfo.UpCredit)
fmt.Printf("DownCredit:")
fmt.Scanf("%d", &uinfo.DownCredit)
fmt.Printf("ExpiryTime:")
fmt.Scanf("%d", &uinfo.ExpiryTime)
marshed, _ := json.Marshal(uinfo)
req = a.request(append([]byte{0x04}, marshed...))
case "5":
UID := promptUID()
fmt.Println("Are you sure to delete this user? y/n")
var ans string
fmt.Scanln(&ans)
if ans != "y" && ans != "Y" {
return
}
req = a.request(append([]byte{0x05}, UID...))
case "6":
UID := promptUID()
req = a.request(append([]byte{0x06}, UID...))
case "7":
arg := make([]byte, 36)
copy(arg, promptUID())
copy(arg[32:], promptUint32("SessionsCap"))
req = a.request(append([]byte{0x07}, arg...))
case "8":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("UpRate"))
req = a.request(append([]byte{0x08}, arg...))
case "9":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("DownRate"))
req = a.request(append([]byte{0x09}, arg...))
case "10":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("UpCredit"))
req = a.request(append([]byte{0x0a}, arg...))
case "11":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("DownCredit"))
req = a.request(append([]byte{0x0b}, arg...))
case "12":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("ExpiryTime"))
req = a.request(append([]byte{0x0c}, arg...))
case "13":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("UpCredit to add"))
req = a.request(append([]byte{0x0d}, arg...))
case "14":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("DownCredit to add"))
req = a.request(append([]byte{0x0e}, arg...))
default:
return nil, errors.New("Unreconised cmd")
}
return req, nil
}
// protocol: 0[TLS record layer 5 bytes]5[IV 16 bytes]21[data][hmac 32 bytes]
func (a *administrator) request(data []byte) []byte {
dataLen := len(data)
buf := make([]byte, 5+16+dataLen+32)
buf[0] = 0x17
buf[1] = 0x03
buf[2] = 0x03
binary.BigEndian.PutUint16(buf[3:5], uint16(16+dataLen+32))
rand.Read(buf[5:21]) //iv
copy(buf[21:], data)
block, _ := aes.NewCipher(a.adminUID[0:16])
stream := cipher.NewCTR(block, buf[5:21])
stream.XORKeyStream(buf[21:21+dataLen], buf[21:21+dataLen])
mac := hmac.New(sha256.New, a.adminUID[16:32])
mac.Write(buf[5 : 21+dataLen])
copy(buf[21+dataLen:], mac.Sum(nil))
return buf
}
var ErrInvalidMac = errors.New("Mac mismatch")
func (a *administrator) checkAndDecrypt(data []byte) ([]byte, error) {
macIndex := len(data) - 32
mac := hmac.New(sha256.New, a.adminUID[16:32])
mac.Write(data[5:macIndex])
expected := mac.Sum(nil)
if !hmac.Equal(data[macIndex:], expected) {
return nil, ErrInvalidMac
}
iv := data[5:21]
ret := data[21:macIndex]
block, _ := aes.NewCipher(a.adminUID[0:16])
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(ret, ret)
return ret, nil
}

View File

@ -1,69 +1,121 @@
//go:build go1.11
// +build go1.11
package main
import (
"encoding/base64"
"encoding/binary"
"flag"
"fmt"
"io"
"log"
"math/rand"
"net"
"os"
"github.com/cbeuw/Cloak/internal/common"
"sync"
"sync/atomic"
"time"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/client/TLS"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/util"
)
var version string
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
// The maximum size of TLS message will be 16396+12. 12 because of the stream header
// 16408 is the max TLS message size on Firefox
buf := make([]byte, 16396)
for {
i, err := io.ReadAtLeast(src, buf, 1)
if err != nil || i == 0 {
go dst.Close()
go src.Close()
return
}
i, err = dst.Write(buf[:i])
if err != nil || i == 0 {
go dst.Close()
go src.Close()
return
}
}
}
// This establishes a connection with ckserver and performs a handshake
func makeRemoteConn(sta *client.State) (net.Conn, error) {
// For android
d := net.Dialer{Control: protector}
clientHello := TLS.ComposeInitHandshake(sta)
remoteConn, err := d.Dial("tcp", sta.SS_REMOTE_HOST+":"+sta.SS_REMOTE_PORT)
if err != nil {
log.Printf("Connecting to remote: %v\n", err)
return nil, err
}
_, err = remoteConn.Write(clientHello)
if err != nil {
log.Printf("Sending ClientHello: %v\n", err)
return nil, err
}
// Three discarded messages: ServerHello, ChangeCipherSpec and Finished
discardBuf := make([]byte, 1024)
for c := 0; c < 3; c++ {
_, err = util.ReadTLS(remoteConn, discardBuf)
if err != nil {
log.Printf("Reading discarded message %v: %v\n", c, err)
return nil, err
}
}
reply := TLS.ComposeReply()
_, err = remoteConn.Write(reply)
if err != nil {
log.Printf("Sending reply to remote: %v\n", err)
return nil, err
}
return remoteConn, nil
}
func main() {
// Should be 127.0.0.1 to listen to a proxy client on this machine
// Should be 127.0.0.1 to listen to ss-local on this machine
var localHost string
// port used by proxy clients to communicate with cloak client
// server_port in ss config, ss sends data on loopback using this port
var localPort string
// The ip of the proxy server
var remoteHost string
// The proxy port,should be 443
var remotePort string
var proxyMethod string
var udp bool
var config string
var b64AdminUID string
var vpnMode bool
var tcpFastOpen bool
var pluginOpts string
isAdmin := new(bool)
log.SetFlags(log.LstdFlags | log.Lshortfile)
log_init()
ssPluginMode := os.Getenv("SS_LOCAL_HOST") != ""
verbosity := flag.String("verbosity", "info", "verbosity level")
if ssPluginMode {
config = os.Getenv("SS_PLUGIN_OPTIONS")
flag.BoolVar(&vpnMode, "V", false, "ignored.")
flag.BoolVar(&tcpFastOpen, "fast-open", false, "ignored.")
flag.Parse() // for verbosity only
if os.Getenv("SS_LOCAL_HOST") != "" {
localHost = os.Getenv("SS_LOCAL_HOST")
localPort = os.Getenv("SS_LOCAL_PORT")
remoteHost = os.Getenv("SS_REMOTE_HOST")
remotePort = os.Getenv("SS_REMOTE_PORT")
pluginOpts = os.Getenv("SS_PLUGIN_OPTIONS")
} else {
flag.StringVar(&localHost, "i", "127.0.0.1", "localHost: Cloak listens to proxy clients on this ip")
flag.StringVar(&localPort, "l", "1984", "localPort: Cloak listens to proxy clients on this port")
localHost = "127.0.0.1"
flag.StringVar(&localPort, "l", "", "localPort: same as server_port in ss config, the plugin listens to SS using this")
flag.StringVar(&remoteHost, "s", "", "remoteHost: IP of your proxy server")
flag.StringVar(&remotePort, "p", "443", "remotePort: proxy port, should be 443")
flag.BoolVar(&udp, "u", false, "udp: set this flag if the underlying proxy is using UDP protocol")
flag.StringVar(&config, "c", "ckclient.json", "config: path to the configuration file or options separated with semicolons")
flag.StringVar(&proxyMethod, "proxy", "", "proxy: the proxy method's name. It must match exactly with the corresponding entry in server's ProxyBook")
flag.StringVar(&b64AdminUID, "a", "", "adminUID: enter the adminUID to serve the admin api")
flag.StringVar(&pluginOpts, "c", "ckclient.json", "pluginOpts: path to ckclient.json or options seperated with semicolons")
askVersion := flag.Bool("v", false, "Print the version number")
isAdmin = flag.Bool("a", false, "Admin mode")
printUsage := flag.Bool("h", false, "Print this message")
// commandline arguments overrides json
flag.Parse()
if *askVersion {
fmt.Printf("ck-client %s", version)
fmt.Printf("ck-client %s\n", version)
return
}
@ -72,135 +124,113 @@ func main() {
return
}
log.Info("Starting standalone mode")
log.Println("Starting standalone mode")
}
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
})
lvl, err := log.ParseLevel(*verbosity)
if *isAdmin {
sta := client.InitState("", "", "", "", time.Now, 0)
err := sta.ParseConfig(pluginOpts)
if err != nil {
log.Fatal(err)
}
log.SetLevel(lvl)
rawConfig, err := client.ParseConfig(config)
err = adminPrompt(sta)
if err != nil {
log.Fatal(err)
log.Println(err)
}
if ssPluginMode {
if rawConfig.ProxyMethod == "" {
rawConfig.ProxyMethod = "shadowsocks"
return
}
// json takes precedence over environment variables
// i.e. if json field isn't empty, use that
if rawConfig.RemoteHost == "" {
rawConfig.RemoteHost = os.Getenv("SS_REMOTE_HOST")
}
if rawConfig.RemotePort == "" {
rawConfig.RemotePort = os.Getenv("SS_REMOTE_PORT")
}
if rawConfig.LocalHost == "" {
rawConfig.LocalHost = os.Getenv("SS_LOCAL_HOST")
}
if rawConfig.LocalPort == "" {
rawConfig.LocalPort = os.Getenv("SS_LOCAL_PORT")
}
} else {
// commandline argument takes precedence over json
// if commandline argument is set, use commandline
flag.Visit(func(f *flag.Flag) {
// manually set ones
switch f.Name {
case "i":
rawConfig.LocalHost = localHost
case "l":
rawConfig.LocalPort = localPort
case "s":
rawConfig.RemoteHost = remoteHost
case "p":
rawConfig.RemotePort = remotePort
case "u":
rawConfig.UDP = udp
case "proxy":
rawConfig.ProxyMethod = proxyMethod
}
})
// ones with default values
if rawConfig.LocalHost == "" {
rawConfig.LocalHost = localHost
}
if rawConfig.LocalPort == "" {
rawConfig.LocalPort = localPort
}
if rawConfig.RemotePort == "" {
rawConfig.RemotePort = remotePort
}
}
localConfig, remoteConfig, authInfo, err := rawConfig.ProcessRawConfig(common.RealWorldState)
if err != nil {
log.Fatal(err)
}
var adminUID []byte
if b64AdminUID != "" {
adminUID, err = base64.StdEncoding.DecodeString(b64AdminUID)
if err != nil {
log.Fatal(err)
}
}
var seshMaker func() *mux.Session
d := &net.Dialer{Control: protector, KeepAlive: remoteConfig.KeepAlive}
if adminUID != nil {
log.Infof("API base is %v", localConfig.LocalAddr)
authInfo.UID = adminUID
authInfo.SessionId = 0
remoteConfig.NumConn = 1
seshMaker = func() *mux.Session {
return client.MakeSession(remoteConfig, authInfo, d)
}
} else {
var network string
if authInfo.Unordered {
network = "UDP"
} else {
network = "TCP"
}
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
seshMaker = func() *mux.Session {
authInfo := authInfo // copy the struct because we are overwriting SessionId
randByte := make([]byte, 1)
common.RandRead(authInfo.WorldState.Rand, randByte)
authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)]
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID.
quad := make([]byte, 4)
common.RandRead(authInfo.WorldState.Rand, quad)
authInfo.SessionId = binary.BigEndian.Uint32(quad)
return client.MakeSession(remoteConfig, authInfo, d)
}
}
rand.Seed(time.Now().UnixNano())
sessionID := rand.Uint32()
if authInfo.Unordered {
acceptor := func() (*net.UDPConn, error) {
udpAddr, _ := net.ResolveUDPAddr("udp", localConfig.LocalAddr)
return net.ListenUDP("udp", udpAddr)
}
client.RouteUDP(acceptor, localConfig.Timeout, remoteConfig.Singleplex, seshMaker)
} else {
listener, err := net.Listen("tcp", localConfig.LocalAddr)
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, sessionID)
err := sta.ParseConfig(pluginOpts)
if err != nil {
log.Fatal(err)
}
client.RouteTCP(listener, localConfig.Timeout, remoteConfig.Singleplex, seshMaker)
if sta.SS_LOCAL_PORT == "" {
log.Fatal("Must specify localPort")
}
if sta.SS_REMOTE_HOST == "" {
log.Fatal("Must specify remoteHost")
}
if sta.TicketTimeHint == 0 {
log.Fatal("TicketTimeHint cannot be empty or 0")
}
listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
log.Println("Listening for ss on " + sta.SS_LOCAL_HOST + ":" + sta.SS_LOCAL_PORT)
if err != nil {
log.Fatal(err)
}
start:
var UNLIMITED int64 = 1e12
valve := mux.MakeValve(1e12, 1e12, &UNLIMITED, &UNLIMITED)
obfs := mux.MakeObfs(sta.UID)
deobfs := mux.MakeDeobfs(sta.UID)
sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS)
var wg sync.WaitGroup
for i := 0; i < sta.NumConn; i++ {
wg.Add(1)
go func() {
makeconn:
conn, err := makeRemoteConn(sta)
if err != nil {
log.Printf("Failed to establish new connections to remote: %v\n", err)
time.Sleep(time.Second * 3)
goto makeconn
}
sesh.AddConnection(conn)
wg.Done()
}()
}
wg.Wait()
var broken uint32
for {
if atomic.LoadUint32(&broken) == 1 {
goto retry
}
ssConn, err := listener.Accept()
if err != nil {
log.Println(err)
continue
}
go func() {
data := make([]byte, 10240)
i, err := io.ReadAtLeast(ssConn, data, 1)
if err != nil {
log.Println(err)
ssConn.Close()
return
}
stream, err := sesh.OpenStream()
if err != nil {
if err == mux.ErrBrokenSession {
atomic.StoreUint32(&broken, 1)
}
log.Println(err)
ssConn.Close()
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Println(err)
ssConn.Close()
stream.Close()
return
}
go pipe(ssConn, stream)
pipe(stream, ssConn)
}()
}
retry:
time.Sleep(time.Second * 3)
log.Println("Reconnecting")
goto start
}

View File

@ -1,4 +1,3 @@
//go:build !android
// +build !android
package main

View File

@ -2,7 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build android
// +build android
package main
@ -29,10 +28,9 @@ import "C"
import (
"bufio"
"log"
"os"
"unsafe"
log "github.com/sirupsen/logrus"
)
var (
@ -68,6 +66,8 @@ func lineLog(f *os.File, priority C.int) {
func log_init() {
log.SetOutput(infoWriter{})
// android logcat includes all of log.LstdFlags
log.SetFlags(log.Flags() &^ log.LstdFlags)
r, w, err := os.Pipe()
if err != nil {

View File

@ -1,4 +1,3 @@
//go:build !android
// +build !android
package main

View File

@ -1,6 +1,4 @@
//go:build android
// +build android
package main
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go
@ -21,8 +19,9 @@ package main
int fd[n]; \
}
int ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds,
void *buffer) {
int
ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds, void *buffer)
{
struct msghdr msghdr;
char nothing = '!';
struct iovec nothing_ptr;
@ -47,28 +46,29 @@ int ancil_send_fds_with_buffer(int sock, const int *fds, unsigned n_fds,
return(sendmsg(sock, &msghdr, 0) >= 0 ? 0 : -1);
}
int ancil_send_fd(int sock, int fd) {
int
ancil_send_fd(int sock, int fd)
{
ANCIL_FD_BUFFER(1) buffer;
return(ancil_send_fds_with_buffer(sock, &fd, 1, &buffer));
}
void set_timeout(int sock) {
void
set_timeout(int sock)
{
struct timeval tv;
tv.tv_sec = 3;
tv.tv_usec = 0;
setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv,
sizeof(struct timeval));
setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv,
sizeof(struct timeval));
setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval));
setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, sizeof(struct timeval));
}
*/
import "C"
import (
"log"
"syscall"
log "github.com/sirupsen/logrus"
)
// In Android, once an app starts the VpnService, all outgoing traffic are routed by the system

View File

@ -1,199 +1,289 @@
package main
import (
"bytes"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
_ "net/http/pprof"
"os"
"runtime"
"strings"
"time"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/server/usermanager"
"github.com/cbeuw/Cloak/internal/util"
)
var version string
func resolveBindAddr(bindAddrs []string) ([]net.Addr, error) {
var addrs []net.Addr
for _, addr := range bindAddrs {
bindAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
// The maximum size of TLS message will be 16396+12. 12 because of the stream header
// 16408 is the max TLS message size on Firefox
buf := make([]byte, 16396)
for {
i, err := io.ReadAtLeast(src, buf, 1)
if err != nil || i == 0 {
go dst.Close()
go src.Close()
return
}
i, err = dst.Write(buf[:i])
if err != nil || i == 0 {
go dst.Close()
go src.Close()
return
}
addrs = append(addrs, bindAddr)
}
return addrs, nil
}
// parse what shadowsocks server wants us to bind and harmonise it with what's already in bindAddr from
// our own config's BindAddr. This prevents duplicate bindings etc.
func parseSSBindAddr(ssRemoteHost string, ssRemotePort string, ckBindAddr *[]net.Addr) error {
var ssBind string
// When listening on an IPv6 and IPv4, SS gives REMOTE_HOST as e.g. ::|0.0.0.0
v4nv6 := len(strings.Split(ssRemoteHost, "|")) == 2
if v4nv6 {
ssBind = ":" + ssRemotePort
} else {
ssBind = net.JoinHostPort(ssRemoteHost, ssRemotePort)
}
ssBindAddr, err := net.ResolveTCPAddr("tcp", ssBind)
func dispatchConnection(conn net.Conn, sta *server.State) {
goWeb := func(data []byte) {
webConn, err := net.Dial("tcp", sta.WebServerAddr)
if err != nil {
return fmt.Errorf("unable to resolve bind address provided by SS: %v", err)
log.Printf("Making connection to redirection server: %v\n", err)
go webConn.Close()
return
}
webConn.Write(data)
go pipe(webConn, conn)
go pipe(conn, webConn)
}
shouldAppend := true
for i, addr := range *ckBindAddr {
if addr.String() == ssBindAddr.String() {
shouldAppend = false
buf := make([]byte, 1500)
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
i, err := io.ReadAtLeast(conn, buf, 1)
if err != nil {
go conn.Close()
return
}
if addr.String() == ":"+ssRemotePort { // already listening on all interfaces
shouldAppend = false
conn.SetReadDeadline(time.Time{})
data := buf[:i]
ch, err := server.ParseClientHello(data)
if err != nil {
log.Printf("+1 non SS non (or malformed) TLS traffic from %v\n", conn.RemoteAddr())
goWeb(data)
return
}
if addr.String() == "0.0.0.0:"+ssRemotePort || addr.String() == "[::]:"+ssRemotePort {
// if config listens on one ip version but ss wants to listen on both,
// listen on both
if ssBindAddr.String() == ":"+ssRemotePort {
shouldAppend = true
(*ckBindAddr)[i] = ssBindAddr
isSS, UID, sessionID := server.TouchStone(ch, sta)
if !isSS {
log.Printf("+1 non SS TLS traffic from %v\n", conn.RemoteAddr())
goWeb(data)
return
}
finishHandshake := func() error {
reply := server.ComposeReply(ch)
_, err = conn.Write(reply)
if err != nil {
go conn.Close()
return err
}
// Two discarded messages: ChangeCipherSpec and Finished
discardBuf := make([]byte, 1024)
for c := 0; c < 2; c++ {
_, err = util.ReadTLS(conn, discardBuf)
if err != nil {
go conn.Close()
return err
}
if shouldAppend {
*ckBindAddr = append(*ckBindAddr, ssBindAddr)
}
return nil
}
func main() {
var config string
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 {
err = finishHandshake()
if err != nil {
log.Println(err)
return
}
c := sta.Userpanel.MakeController(sta.AdminUID)
for {
n, err := conn.Read(data)
if err != nil {
log.Println(err)
return
}
resp, err := c.HandleRequest(data[:n])
if err != nil {
log.Println(err)
return
}
_, err = conn.Write(resp)
if err != nil {
log.Println(err)
return
}
}
var pluginMode bool
}
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
})
if os.Getenv("SS_LOCAL_HOST") != "" && os.Getenv("SS_LOCAL_PORT") != "" {
pluginMode = true
config = os.Getenv("SS_PLUGIN_OPTIONS")
var user *usermanager.User
if bytes.Equal(UID, sta.AdminUID) {
user, err = sta.Userpanel.GetAndActivateAdminUser(UID)
} else {
flag.StringVar(&config, "c", "server.json", "config: path to the configuration file or its content")
user, err = sta.Userpanel.GetAndActivateUser(UID)
}
if err != nil {
log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID)
goWeb(data)
return
}
err = finishHandshake()
if err != nil {
log.Println(err)
return
}
sesh, existing, err := user.GetSession(sessionID, mux.MakeObfs(UID), mux.MakeDeobfs(UID), util.ReadTLS)
if err != nil {
user.DelSession(sessionID)
log.Println(err)
return
}
if existing {
sesh.AddConnection(conn)
return
} else {
log.Printf("UID: %x\n", UID)
sesh.AddConnection(conn)
for {
newStream, err := sesh.AcceptStream()
if err != nil {
log.Printf("Failed to get new stream: %v\n", err)
if err == mux.ErrBrokenSession {
log.Printf("Session closed: %x:%v\n", UID, sessionID)
user.DelSession(sessionID)
return
} else {
continue
}
}
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil {
log.Printf("Failed to connect to ssserver: %v\n", err)
continue
}
go pipe(ssConn, newStream)
go pipe(newStream, ssConn)
}
}
}
func main() {
runtime.SetBlockProfileRate(5)
go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}()
// Should be 127.0.0.1 to listen to ss-server on this machine
var localHost string
// server_port in ss config, same as remotePort in plugin mode
var localPort string
// server in ss config, the outbound listening ip
var remoteHost string
// Outbound listening ip, should be 443
var remotePort string
var pluginOpts string
log.SetFlags(log.LstdFlags | log.Lshortfile)
if os.Getenv("SS_LOCAL_HOST") != "" {
localHost = os.Getenv("SS_LOCAL_HOST")
localPort = os.Getenv("SS_LOCAL_PORT")
remoteHost = os.Getenv("SS_REMOTE_HOST")
remotePort = os.Getenv("SS_REMOTE_PORT")
pluginOpts = os.Getenv("SS_PLUGIN_OPTIONS")
} else {
localAddr := flag.String("r", "", "localAddr: 127.0.0.1:server_port as set in SS config")
flag.StringVar(&remoteHost, "s", "0.0.0.0", "remoteHost: outbound listing ip, set to 0.0.0.0 to listen to everything")
flag.StringVar(&remotePort, "p", "443", "remotePort: outbound listing port, should be 443")
flag.StringVar(&pluginOpts, "c", "server.json", "pluginOpts: path to server.json or options seperated by semicolons")
askVersion := flag.Bool("v", false, "Print the version number")
printUsage := flag.Bool("h", false, "Print this message")
genUIDScript := flag.Bool("u", false, "Generate a UID to STDOUT")
genKeyPairScript := flag.Bool("k", false, "Generate a pair of public and private key and output to STDOUT in the format of <public key>,<private key>")
genUIDHuman := flag.Bool("uid", false, "Generate and print out a UID")
genKeyPairHuman := flag.Bool("key", false, "Generate and print out a public-private key pair")
pprofAddr := flag.String("d", "", "debug use: ip:port to be listened by pprof profiler")
verbosity := flag.String("verbosity", "info", "verbosity level")
genUID := flag.Bool("u", false, "Generate a UID")
genKeyPair := flag.Bool("k", false, "Generate a pair of public and private key, output in the format of pubkey,pvkey")
flag.Parse()
if *askVersion {
fmt.Printf("ck-server %s", version)
fmt.Printf("ck-server %s\n", version)
return
}
if *printUsage {
flag.Usage()
return
}
if *genUIDScript || *genUIDHuman {
uid := generateUID()
if *genUIDScript {
fmt.Println(uid)
} else {
fmt.Printf("\x1B[35mYour UID is:\u001B[0m %s\n", uid)
}
if *genUID {
fmt.Println(generateUID())
return
}
if *genKeyPairScript || *genKeyPairHuman {
if *genKeyPair {
pub, pv := generateKeyPair()
if *genKeyPairScript {
fmt.Printf("%v,%v\n", pub, pv)
} else {
fmt.Printf("\x1B[36mYour PUBLIC key is:\x1B[0m %65s\n", pub)
fmt.Printf("\x1B[33mYour PRIVATE key is (keep it secret):\x1B[0m %47s\n", pv)
}
fmt.Printf("%v,%v", pub, pv)
return
}
if *pprofAddr != "" {
runtime.SetBlockProfileRate(5)
go func() {
log.Info(http.ListenAndServe(*pprofAddr, nil))
}()
log.Infof("pprof listening on %v", *pprofAddr)
if *localAddr == "" {
log.Fatal("Must specify localAddr")
}
lvl, err := log.ParseLevel(*verbosity)
if err != nil {
log.Fatal(err)
localHost = strings.Split(*localAddr, ":")[0]
localPort = strings.Split(*localAddr, ":")[1]
log.Printf("Starting standalone mode, listening on %v:%v to ss at %v:%v\n", remoteHost, remotePort, localHost, localPort)
}
log.SetLevel(lvl)
sta, _ := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now)
log.Infof("Starting standalone mode")
}
raw, err := server.ParseConfig(config)
err := sta.ParseConfig(pluginOpts)
if err != nil {
log.Fatalf("Configuration file error: %v", err)
}
bindAddr, err := resolveBindAddr(raw.BindAddr)
if err != nil {
log.Fatalf("unable to parse BindAddr: %v", err)
if sta.AdminUID == nil {
log.Fatalln("AdminUID cannot be empty!")
}
// in case the user hasn't specified any local address to bind to, we listen on 443 and 80
if !pluginMode && len(bindAddr) == 0 {
https, _ := net.ResolveTCPAddr("tcp", ":443")
http, _ := net.ResolveTCPAddr("tcp", ":80")
bindAddr = []net.Addr{https, http}
}
go sta.UsedRandomCleaner()
// when cloak is started as a shadowsocks plugin, we parse the address ss-server
// is listening on into ProxyBook, and we parse the list of bindAddr
if pluginMode {
ssLocalHost := os.Getenv("SS_LOCAL_HOST")
ssLocalPort := os.Getenv("SS_LOCAL_PORT")
raw.ProxyBook["shadowsocks"] = []string{"tcp", net.JoinHostPort(ssLocalHost, ssLocalPort)}
ssRemoteHost := os.Getenv("SS_REMOTE_HOST")
ssRemotePort := os.Getenv("SS_REMOTE_PORT")
err = parseSSBindAddr(ssRemoteHost, ssRemotePort, &bindAddr)
if err != nil {
log.Fatalf("failed to parse SS_REMOTE_HOST and SS_REMOTE_PORT: %v", err)
}
}
sta, err := server.InitState(raw, common.RealWorldState)
if err != nil {
log.Fatalf("unable to initialise server state: %v", err)
}
listen := func(bindAddr net.Addr) {
listener, err := net.Listen("tcp", bindAddr.String())
log.Infof("Listening on %v", bindAddr)
listen := func(addr, port string) {
listener, err := net.Listen("tcp", addr+":"+port)
log.Println("Listening on " + addr + ":" + port)
if err != nil {
log.Fatal(err)
}
server.Serve(listener, sta)
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("%v", err)
continue
}
go dispatchConnection(conn, sta)
}
}
for i, addr := range bindAddr {
if i != len(bindAddr)-1 {
go listen(addr)
// When listening on an IPv6 and IPv4, SS gives REMOTE_HOST as e.g. ::|0.0.0.0
listeningIP := strings.Split(sta.SS_REMOTE_HOST, "|")
for i, ip := range listeningIP {
if net.ParseIP(ip).To4() == nil {
// IPv6 needs square brackets
ip = "[" + ip + "]"
}
// The last listener must block main() because the program exits on main return.
if i == len(listeningIP)-1 {
listen(ip, sta.SS_REMOTE_PORT)
} else {
// we block the main goroutine here so it doesn't quit
listen(addr)
go listen(ip, sta.SS_REMOTE_PORT)
}
}

View File

@ -1,136 +0,0 @@
package main
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseBindAddr(t *testing.T) {
t.Run("port only", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":443"})
assert.NoError(t, err)
assert.Equal(t, ":443", addrs[0].String())
})
t.Run("specific address", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
assert.NoError(t, err)
assert.Equal(t, "192.168.1.123:443", addrs[0].String())
})
t.Run("ipv6", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"[::]:443"})
assert.NoError(t, err)
assert.Equal(t, "[::]:443", addrs[0].String())
})
t.Run("mixed", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
assert.NoError(t, err)
assert.Equal(t, ":80", addrs[0].String())
assert.Equal(t, "[::]:443", addrs[1].String())
})
}
func assertSetEqual(t *testing.T, list1, list2 interface{}, msgAndArgs ...interface{}) (ok bool) {
return assert.Subset(t, list1, list2, msgAndArgs) && assert.Subset(t, list2, list1, msgAndArgs)
}
func TestParseSSBindAddr(t *testing.T) {
testTable := []struct {
name string
ssRemoteHost string
ssRemotePort string
ckBindAddr []net.Addr
expectedAddr []net.Addr
}{
{
"ss only ipv4",
"127.0.0.1",
"443",
[]net.Addr{},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 443,
},
},
},
{
"ss only ipv6",
"::",
"443",
[]net.Addr{},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
},
//{
// "ss only ipv4 and v6",
// "::|127.0.0.1",
// "443",
// []net.Addr{},
// []net.Addr{
// &net.TCPAddr{
// IP: net.ParseIP("::"),
// Port: 443,
// },
// &net.TCPAddr{
// IP: net.ParseIP("127.0.0.1"),
// Port: 443,
// },
// },
//},
{
"ss and existing agrees",
"::",
"443",
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
},
{
"ss adds onto existing",
"127.0.0.1",
"80",
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
},
[]net.Addr{
&net.TCPAddr{
IP: net.ParseIP("::"),
Port: 443,
},
&net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 80,
},
},
},
}
for _, test := range testTable {
test := test
t.Run(test.name, func(t *testing.T) {
assert.NoError(t, parseSSBindAddr(test.ssRemoteHost, test.ssRemotePort, &test.ckBindAddr))
assertSetEqual(t, test.ckBindAddr, test.expectedAddr)
})
}
}

View File

@ -3,20 +3,21 @@ package main
import (
"crypto/rand"
"encoding/base64"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh"
ecdh "github.com/cbeuw/go-ecdh"
)
var b64 = base64.StdEncoding.EncodeToString
func generateUID() string {
UID := make([]byte, 16)
common.CryptoRandRead(UID)
return base64.StdEncoding.EncodeToString(UID)
UID := make([]byte, 32)
rand.Read(UID)
return b64(UID)
}
func generateKeyPair() (string, string) {
staticPv, staticPub, _ := ecdh.GenerateKey(rand.Reader)
marshPub := ecdh.Marshal(staticPub)
ec := ecdh.NewCurve25519ECDH()
staticPv, staticPub, _ := ec.GenerateKey(rand.Reader)
marshPub := ec.Marshal(staticPub)
marshPv := staticPv.(*[32]byte)[:]
return base64.StdEncoding.EncodeToString(marshPub), base64.StdEncoding.EncodeToString(marshPv)
return b64(marshPub), b64(marshPv)
}

View File

@ -1,4 +0,0 @@
coverage:
status:
project: off
patch: off

8
config/ckclient.json Normal file
View File

@ -0,0 +1,8 @@
{
"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=",
"PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=",
"ServerName":"www.bing.com",
"TicketTimeHint":3600,
"NumConn":4,
"MaskBrowser":"chrome"
}

7
config/ckserver.json Normal file
View File

@ -0,0 +1,7 @@
{
"WebServerAddr":"204.79.197.200:443",
"PrivateKey":"EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=",
"AdminUID":"ugDmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ=",
"DatabasePath":"userinfo.db",
"BackupDirPath":""
}

View File

@ -1,11 +0,0 @@
{
"Transport": "direct",
"ProxyMethod": "shadowsocks",
"EncryptionMethod": "plain",
"UID": "---Your UID here---",
"PublicKey": "---Public key here---",
"ServerName": "www.bing.com",
"NumConn": 4,
"BrowserSig": "chrome",
"StreamTimeout": 300
}

View File

@ -1,27 +0,0 @@
{
"ProxyBook": {
"shadowsocks": [
"tcp",
"127.0.0.1:8388"
],
"openvpn": [
"udp",
"127.0.0.1:8389"
],
"tor": [
"tcp",
"127.0.0.1:9001"
]
},
"BindAddr": [
":443",
":80"
],
"BypassUID": [
"---Bypass UID here---"
],
"RedirAddr": "cloudflare.com",
"PrivateKey": "---Private key here---",
"AdminUID": "---Admin UID here (optional)---",
"DatabasePath": "userinfo.db"
}

30
go.mod
View File

@ -1,30 +0,0 @@
module github.com/cbeuw/Cloak
go 1.24.0
toolchain go1.24.2
require (
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.3
github.com/juju/ratelimit v1.0.2
github.com/refraction-networking/utls v1.8.0
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.10.0
go.etcd.io/bbolt v1.4.0
golang.org/x/crypto v0.37.0
)
require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
golang.org/x/sys v0.32.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

61
go.sum
View File

@ -1,61 +0,0 @@
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3 h1:LRxW8pdmWmyhoNh+TxUjxsAinGtCsVGjsl3xg6zoRSs=
github.com/cbeuw/connutil v0.0.0-20200411215123-966bfaa51ee3/go.mod h1:6jR2SzckGv8hIIS9zWJ160mzGVVOYp4AXZMDtacL6LE=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI=
github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.6.6 h1:igFsYBUJPYM8Rno9xUuDoM5GQrVEqY4llzEXOkL43Ig=
github.com/refraction-networking/utls v1.6.6/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0=
github.com/refraction-networking/utls v1.7.0/go.mod h1:lV0Gwc1/Fi+HYH8hOtgFRdHfKo4FKSn6+FdyOz9hRms=
github.com/refraction-networking/utls v1.7.3 h1:L0WRhHY7Oq1T0zkdzVZMR6zWZv+sXbHB9zcuvsAEqCo=
github.com/refraction-networking/utls v1.7.3/go.mod h1:TUhh27RHMGtQvjQq+RyO11P6ZNQNBb3N0v7wsEjKAIQ=
github.com/refraction-networking/utls v1.8.0 h1:L38krhiTAyj9EeiQQa2sg+hYb4qwLCqdMcpZrRfbONE=
github.com/refraction-networking/utls v1.8.0/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
go.etcd.io/bbolt v1.4.0 h1:TU77id3TnN/zKr7CO/uk+fBCwF2jGcMuw2B/FMAzYIk=
go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,168 +0,0 @@
package client
import (
"github.com/cbeuw/Cloak/internal/common"
utls "github.com/refraction-networking/utls"
log "github.com/sirupsen/logrus"
"net"
"strings"
)
const appDataMaxLength = 16401
type clientHelloFields struct {
random []byte
sessionId []byte
x25519KeyShare []byte
serverName string
}
type browser int
const (
chrome = iota
firefox
safari
)
type DirectTLS struct {
*common.TLSConn
browser browser
}
var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"}
func randomServerName() string {
/*
Copyright: Proton AG
https://github.com/ProtonVPN/wireguard-go/commit/bcf344b39b213c1f32147851af0d2a8da9266883
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
charNum := int('z') - int('a') + 1
size := 3 + common.RandInt(10)
name := make([]byte, size)
for i := range name {
name[i] = byte(int('a') + common.RandInt(charNum))
}
return string(name) + "." + common.RandItem(topLevelDomains)
}
func buildClientHello(browser browser, fields clientHelloFields) ([]byte, error) {
// We don't use utls to handle connections (as it'll attempt a real TLS negotiation)
// We only want it to build the ClientHello locally
fakeConn := net.TCPConn{}
var helloID utls.ClientHelloID
switch browser {
case chrome:
helloID = utls.HelloChrome_Auto
case firefox:
helloID = utls.HelloFirefox_Auto
case safari:
helloID = utls.HelloSafari_Auto
}
uclient := utls.UClient(&fakeConn, &utls.Config{ServerName: fields.serverName}, helloID)
if err := uclient.BuildHandshakeState(); err != nil {
return []byte{}, err
}
if err := uclient.SetClientRandom(fields.random); err != nil {
return []byte{}, err
}
uclient.HandshakeState.Hello.SessionId = make([]byte, 32)
copy(uclient.HandshakeState.Hello.SessionId, fields.sessionId)
// Find the X25519 key share and overwrite it
var extIndex int
var keyShareIndex int
for i, ext := range uclient.Extensions {
ext, ok := ext.(*utls.KeyShareExtension)
if ok {
extIndex = i
for j, keyShare := range ext.KeyShares {
if keyShare.Group == utls.X25519 {
keyShareIndex = j
}
}
}
}
copy(uclient.Extensions[extIndex].(*utls.KeyShareExtension).KeyShares[keyShareIndex].Data, fields.x25519KeyShare)
if err := uclient.BuildHandshakeState(); err != nil {
return []byte{}, err
}
return uclient.HandshakeState.Hello.Raw, nil
}
// Handshake handles the TLS handshake for a given conn and returns the sessionKey
// if the server proceed with Cloak authentication
func (tls *DirectTLS) Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) {
payload, sharedSecret := makeAuthenticationPayload(authInfo)
fields := clientHelloFields{
random: payload.randPubKey[:],
sessionId: payload.ciphertextWithTag[0:32],
x25519KeyShare: payload.ciphertextWithTag[32:64],
serverName: authInfo.MockDomain,
}
if strings.EqualFold(fields.serverName, "random") {
fields.serverName = randomServerName()
}
var ch []byte
ch, err = buildClientHello(tls.browser, fields)
if err != nil {
return
}
chWithRecordLayer := common.AddRecordLayer(ch, common.Handshake, common.VersionTLS11)
_, err = rawConn.Write(chWithRecordLayer)
if err != nil {
return
}
log.Trace("client hello sent successfully")
tls.TLSConn = common.NewTLSConn(rawConn)
buf := make([]byte, 1024)
log.Trace("waiting for ServerHello")
_, err = tls.Read(buf)
if err != nil {
return
}
encrypted := append(buf[6:38], buf[84:116]...)
nonce := encrypted[0:12]
ciphertextWithTag := encrypted[12:60]
sessionKeySlice, err := common.AESGCMDecrypt(nonce, sharedSecret[:], ciphertextWithTag)
if err != nil {
return
}
copy(sessionKey[:], sessionKeySlice)
for i := 0; i < 2; i++ {
// ChangeCipherSpec and EncryptedCert (in the format of application data)
_, err = tls.Read(buf)
if err != nil {
return
}
}
return sessionKey, nil
}

View File

@ -0,0 +1,70 @@
package TLS
import (
"encoding/binary"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/util"
"time"
)
type browser interface {
composeExtensions()
composeClientHello()
}
func makeServerName(sta *client.State) []byte {
serverName := sta.ServerName
serverNameListLength := make([]byte, 2)
binary.BigEndian.PutUint16(serverNameListLength, uint16(len(serverName)+3))
serverNameType := []byte{0x00} // host_name
serverNameLength := make([]byte, 2)
binary.BigEndian.PutUint16(serverNameLength, uint16(len(serverName)))
ret := make([]byte, 2+1+2+len(serverName))
copy(ret[0:2], serverNameListLength)
copy(ret[2:3], serverNameType)
copy(ret[3:5], serverNameLength)
copy(ret[5:], serverName)
return ret
}
func makeNullBytes(length int) []byte {
ret := make([]byte, length)
for i := 0; i < length; i++ {
ret[i] = 0x00
}
return ret
}
// addExtensionRecord, add type, length to extension data
func addExtRec(typ []byte, data []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(data)))
ret := make([]byte, 2+2+len(data))
copy(ret[0:2], typ)
copy(ret[2:4], length)
copy(ret[4:], data)
return ret
}
// ComposeInitHandshake composes ClientHello with record layer
func ComposeInitHandshake(sta *client.State) []byte {
var ch []byte
switch sta.MaskBrowser {
case "chrome":
ch = (&chrome{}).composeClientHello(sta)
case "firefox":
ch = (&firefox{}).composeClientHello(sta)
default:
panic("Unsupported browser:" + sta.MaskBrowser)
}
return util.AddRecordLayer(ch, []byte{0x16}, []byte{0x03, 0x01})
}
// ComposeReply composes RL+ChangeCipherSpec+RL+Finished
func ComposeReply() []byte {
TLS12 := []byte{0x03, 0x03}
ccsBytes := util.AddRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
finished := util.PsudoRandBytes(40, time.Now().UnixNano())
fBytes := util.AddRecordLayer(finished, []byte{0x16}, TLS12)
return append(ccsBytes, fBytes...)
}

View File

@ -0,0 +1,82 @@
// Chrome 64
package TLS
import (
"encoding/hex"
"math/rand"
"time"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/util"
)
type chrome struct {
browser
}
func (c *chrome) composeExtensions(sta *client.State) []byte {
// see https://tools.ietf.org/html/draft-davidben-tls-grease-01
// This is exclusive to chrome.
makeGREASE := func() []byte {
rand.Seed(time.Now().UnixNano())
sixteenth := rand.Intn(16)
monoGREASE := byte(sixteenth*16 + 0xA)
doubleGREASE := []byte{monoGREASE, monoGREASE}
return doubleGREASE
}
makeSupportedGroups := func() []byte {
suppGroupListLen := []byte{0x00, 0x08}
ret := make([]byte, 2+8)
copy(ret[0:2], suppGroupListLen)
copy(ret[2:4], makeGREASE())
copy(ret[4:], []byte{0x00, 0x1d, 0x00, 0x17, 0x00, 0x18})
return ret
}
var ext [14][]byte
ext[0] = addExtRec(makeGREASE(), nil) // First GREASE
ext[1] = addExtRec([]byte{0xff, 0x01}, []byte{0x00}) // renegotiation_info
ext[2] = addExtRec([]byte{0x00, 0x00}, makeServerName(sta)) // server name indication
ext[3] = addExtRec([]byte{0x00, 0x17}, nil) // extended_master_secret
ext[4] = addExtRec([]byte{0x00, 0x23}, client.MakeSessionTicket(sta)) // Session tickets
sigAlgo, _ := hex.DecodeString("0012040308040401050308050501080606010201")
ext[5] = addExtRec([]byte{0x00, 0x0d}, sigAlgo) // Signature Algorithms
ext[6] = addExtRec([]byte{0x00, 0x05}, []byte{0x01, 0x00, 0x00, 0x00, 0x00}) // status request
ext[7] = addExtRec([]byte{0x00, 0x12}, nil) // signed cert timestamp
APLN, _ := hex.DecodeString("000c02683208687474702f312e31")
ext[8] = addExtRec([]byte{0x00, 0x10}, APLN) // app layer proto negotiation
ext[9] = addExtRec([]byte{0x75, 0x50}, nil) // channel id
ext[10] = addExtRec([]byte{0x00, 0x0b}, []byte{0x01, 0x00}) // ec point formats
ext[11] = addExtRec([]byte{0x00, 0x0a}, makeSupportedGroups()) // supported groups
ext[12] = addExtRec(makeGREASE(), []byte{0x00}) // Last GREASE
ext[13] = addExtRec([]byte{0x00, 0x15}, makeNullBytes(110-len(ext[2]))) // padding
var ret []byte
for i := 0; i < 14; i++ {
ret = append(ret, ext[i]...)
}
return ret
}
func (c *chrome) composeClientHello(sta *client.State) []byte {
var clientHello [12][]byte
clientHello[0] = []byte{0x01} // handshake type
clientHello[1] = []byte{0x00, 0x01, 0xfc} // length 508
clientHello[2] = []byte{0x03, 0x03} // client version
clientHello[3] = client.MakeRandomField(sta) // random
clientHello[4] = []byte{0x20} // session id length 32
clientHello[5] = util.PsudoRandBytes(32, sta.Now().UnixNano()) // session id
clientHello[6] = []byte{0x00, 0x1c} // cipher suites length 28
cipherSuites, _ := hex.DecodeString("2a2ac02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a")
clientHello[7] = cipherSuites // cipher suites
clientHello[8] = []byte{0x01} // compression methods length 1
clientHello[9] = []byte{0x00} // compression methods
clientHello[10] = []byte{0x01, 0x97} // extensions length 407
clientHello[11] = c.composeExtensions(sta) // extensions
var ret []byte
for i := 0; i < 12; i++ {
ret = append(ret, clientHello[i]...)
}
return ret
}

View File

@ -0,0 +1,57 @@
// Firefox 58
package TLS
import (
"encoding/hex"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/util"
)
type firefox struct {
browser
}
func (f *firefox) composeExtensions(sta *client.State) []byte {
var ext [10][]byte
ext[0] = addExtRec([]byte{0x00, 0x00}, makeServerName(sta)) // server name indication
ext[1] = addExtRec([]byte{0x00, 0x17}, nil) // extended_master_secret
ext[2] = addExtRec([]byte{0xff, 0x01}, []byte{0x00}) // renegotiation_info
suppGroup, _ := hex.DecodeString("0008001d001700180019")
ext[3] = addExtRec([]byte{0x00, 0x0a}, suppGroup) // supported groups
ext[4] = addExtRec([]byte{0x00, 0x0b}, []byte{0x01, 0x00}) // ec point formats
ext[5] = addExtRec([]byte{0x00, 0x23}, client.MakeSessionTicket(sta)) // Session tickets
APLN, _ := hex.DecodeString("000c02683208687474702f312e31")
ext[6] = addExtRec([]byte{0x00, 0x10}, APLN) // app layer proto negotiation
ext[7] = addExtRec([]byte{0x00, 0x05}, []byte{0x01, 0x00, 0x00, 0x00, 0x00}) // status request
sigAlgo, _ := hex.DecodeString("001604030503060308040805080604010501060102030201")
ext[8] = addExtRec([]byte{0x00, 0x0d}, sigAlgo) // Signature Algorithms
ext[9] = addExtRec([]byte{0x00, 0x15}, makeNullBytes(121-len(ext[0]))) // padding
var ret []byte
for i := 0; i < 10; i++ {
ret = append(ret, ext[i]...)
}
return ret
}
func (f *firefox) composeClientHello(sta *client.State) []byte {
var clientHello [12][]byte
clientHello[0] = []byte{0x01} // handshake type
clientHello[1] = []byte{0x00, 0x01, 0xfc} // length 508
clientHello[2] = []byte{0x03, 0x03} // client version
clientHello[3] = client.MakeRandomField(sta) // random
clientHello[4] = []byte{0x20} // session id length 32
clientHello[5] = util.PsudoRandBytes(32, sta.Now().UnixNano()) // session id
clientHello[6] = []byte{0x00, 0x1e} // cipher suites length 28
cipherSuites, _ := hex.DecodeString("c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a")
clientHello[7] = cipherSuites // cipher suites
clientHello[8] = []byte{0x01} // compression methods length 1
clientHello[9] = []byte{0x00} // compression methods
clientHello[10] = []byte{0x01, 0x95} // extensions length 405
clientHello[11] = f.composeExtensions(sta) // extensions
var ret []byte
for i := 0; i < 12; i++ {
ret = append(ret, clientHello[i]...)
}
return ret
}

View File

@ -1,56 +1,70 @@
package client
import (
"crypto"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/util"
ecdh "github.com/cbeuw/go-ecdh"
"io"
)
const (
UNORDERED_FLAG = 0x01 // 0000 0001
)
type authenticationPayload struct {
randPubKey [32]byte
ciphertextWithTag [64]byte
type keyPair struct {
crypto.PrivateKey
crypto.PublicKey
}
// makeAuthenticationPayload generates the ephemeral key pair, calculates the shared secret, and then compose and
// encrypt the authenticationPayload
func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sharedSecret [32]byte) {
/*
Authentication data:
+----------+----------------+---------------------+-------------+--------------+--------+------------+
| _UID_ | _Proxy Method_ | _Encryption Method_ | _Timestamp_ | _Session Id_ | _Flag_ | _reserved_ |
+----------+----------------+---------------------+-------------+--------------+--------+------------+
| 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes |
+----------+----------------+---------------------+-------------+--------------+--------+------------+
*/
ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand)
if err != nil {
log.Panicf("failed to generate ephemeral key pair: %v", err)
}
copy(ret.randPubKey[:], ecdh.Marshal(ephPub))
plaintext := make([]byte, 48)
copy(plaintext, authInfo.UID)
copy(plaintext[16:28], authInfo.ProxyMethod)
plaintext[28] = authInfo.EncryptionMethod
binary.BigEndian.PutUint64(plaintext[29:37], uint64(authInfo.WorldState.Now().UTC().Unix()))
binary.BigEndian.PutUint32(plaintext[37:41], authInfo.SessionId)
if authInfo.Unordered {
plaintext[41] |= UNORDERED_FLAG
func MakeRandomField(sta *State) []byte {
t := make([]byte, 8)
binary.BigEndian.PutUint64(t, uint64(sta.Now().Unix()/(12*60*60)))
rdm := make([]byte, 16)
io.ReadFull(rand.Reader, rdm)
preHash := make([]byte, 56)
copy(preHash[0:32], sta.UID)
copy(preHash[32:40], t)
copy(preHash[40:56], rdm)
h := sha256.New()
h.Write(preHash)
ret := make([]byte, 32)
copy(ret[0:16], rdm)
copy(ret[16:32], h.Sum(nil)[0:16])
return ret
}
secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)
if err != nil {
log.Panicf("error in generating shared secret: %v", err)
func MakeSessionTicket(sta *State) []byte {
// sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted UID+sessionID 36 bytes][padding 124 bytes]
// The first 16 bytes of the marshalled ephemeral public key is used as the IV
// for encrypting the UID
tthInterval := sta.Now().Unix() / int64(sta.TicketTimeHint)
ec := ecdh.NewCurve25519ECDH()
ephKP := sta.getKeyPair(tthInterval)
if ephKP == nil {
ephPv, ephPub, _ := ec.GenerateKey(rand.Reader)
ephKP = &keyPair{
ephPv,
ephPub,
}
copy(sharedSecret[:], secret)
ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext)
copy(ret.ciphertextWithTag[:], ciphertextWithTag[:])
return
sta.putKeyPair(tthInterval, ephKP)
}
ticket := make([]byte, 192)
copy(ticket[0:32], ec.Marshal(ephKP.PublicKey))
key, _ := ec.GenerateSharedSecret(ephKP.PrivateKey, sta.staticPub)
plainUIDsID := make([]byte, 36)
copy(plainUIDsID, sta.UID)
binary.BigEndian.PutUint32(plainUIDsID[32:36], sta.sessionID)
cipherUIDsID := util.AESEncrypt(ticket[0:16], key, plainUIDsID)
copy(ticket[32:68], cipherUIDsID)
// The purpose of adding sessionID is that, the generated padding of sessionTicket needs to be unpredictable.
// As shown in auth.go, the padding is generated by a psudo random generator. The seed
// needs to be the same for each TicketTimeHint interval. However the value of epoch/TicketTimeHint
// is public knowledge, so is the psudo random algorithm used by math/rand. Therefore not only
// can the firewall tell that the padding is generated in this specific way, this padding is identical
// for all ckclients in the same TicketTimeHint interval. This will expose us.
//
// With the sessionID value generated at startup of ckclient and used as a part of the seed, the
// sessionTicket is still identical for each TicketTimeHint interval, but others won't be able to know
// how it was generated. It will also be different for each client.
copy(ticket[68:192], util.PsudoRandBytes(124, tthInterval+int64(sta.sessionID)))
return ticket
}

View File

@ -1,73 +0,0 @@
package client
import (
"bytes"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/multiplex"
"github.com/stretchr/testify/assert"
)
func TestMakeAuthenticationPayload(t *testing.T) {
tests := []struct {
authInfo AuthInfo
expPayload authenticationPayload
expSecret [32]byte
}{
{
AuthInfo{
Unordered: false,
SessionId: 3421516597,
UID: []byte{
0x4c, 0xd8, 0xcc, 0x15, 0x60, 0x0d, 0x7e,
0xb6, 0x81, 0x31, 0xfd, 0x80, 0x97, 0x67, 0x37, 0x46},
ServerPubKey: &[32]byte{
0x21, 0x8a, 0x14, 0xce, 0x49, 0x5e, 0xfd, 0x3f,
0xe4, 0xae, 0x21, 0x3e, 0x51, 0xf7, 0x66, 0xec,
0x01, 0xd0, 0xb4, 0x87, 0x86, 0x9c, 0x15, 0x9b,
0x86, 0x19, 0x53, 0x6e, 0x60, 0xe9, 0x51, 0x42},
ProxyMethod: "shadowsocks",
EncryptionMethod: multiplex.EncryptionMethodPlain,
MockDomain: "d2jkinvisak5y9.cloudfront.net",
WorldState: common.WorldState{
Rand: bytes.NewBuffer([]byte{
0xf1, 0x1e, 0x42, 0xe1, 0x84, 0x22, 0x07, 0xc5,
0xc3, 0x5c, 0x0f, 0x7b, 0x01, 0xf3, 0x65, 0x2d,
0xd7, 0x9b, 0xad, 0xb0, 0xb2, 0x77, 0xa2, 0x06,
0x6b, 0x78, 0x1b, 0x74, 0x1f, 0x43, 0xc9, 0x80}),
Now: func() time.Time { return time.Unix(1579908372, 0) },
},
},
authenticationPayload{
randPubKey: [32]byte{
0xee, 0x9e, 0x41, 0x4e, 0xb3, 0x3b, 0x85, 0x03,
0x6d, 0x85, 0xba, 0x30, 0x11, 0x31, 0x10, 0x24,
0x4f, 0x7b, 0xd5, 0x38, 0x50, 0x0f, 0xf2, 0x4d,
0xa3, 0xdf, 0xba, 0x76, 0x0a, 0xe9, 0x19, 0x19},
ciphertextWithTag: [64]byte{
0x71, 0xb1, 0x6c, 0x5a, 0x60, 0x46, 0x90, 0x12,
0x36, 0x3b, 0x1b, 0xc4, 0x79, 0x3c, 0xab, 0xdd,
0x5a, 0x53, 0xc5, 0xed, 0xaf, 0xdb, 0x10, 0x98,
0x83, 0x96, 0x81, 0xa6, 0xfc, 0xa2, 0x1e, 0xb0,
0x89, 0xb2, 0x29, 0x71, 0x7e, 0x45, 0x97, 0x54,
0x11, 0x7d, 0x9b, 0x92, 0xbb, 0xd6, 0xce, 0x37,
0x3b, 0xb8, 0x8b, 0xfb, 0xb6, 0x40, 0xf0, 0x2c,
0x6c, 0x55, 0xb9, 0xfc, 0x5d, 0x34, 0x89, 0x41},
},
[32]byte{
0xc7, 0xc6, 0x9b, 0xbe, 0xec, 0xf8, 0x35, 0x55,
0x67, 0x20, 0xcd, 0xeb, 0x74, 0x16, 0xc5, 0x60,
0xee, 0x9d, 0x63, 0x1a, 0x44, 0xc5, 0x09, 0xf6,
0xe0, 0x24, 0xad, 0xd2, 0x10, 0xe3, 0x4a, 0x11},
},
}
for _, tc := range tests {
func() {
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
assert.Equal(t, tc.expPayload, payload, "payload doesn't match")
assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match")
}()
}
}

View File

@ -1,83 +0,0 @@
package client
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
// On different invocations to MakeSession, authInfo.SessionId MUST be different
func MakeSession(connConfig RemoteConnConfig, authInfo AuthInfo, dialer common.Dialer) *mux.Session {
log.Info("Attempting to start a new session")
connsCh := make(chan net.Conn, connConfig.NumConn)
var _sessionKey atomic.Value
var wg sync.WaitGroup
for i := 0; i < connConfig.NumConn; i++ {
wg.Add(1)
transportConfig := connConfig.Transport
go func() {
makeconn:
transportConn := transportConfig.CreateTransport()
remoteConn, err := dialer.Dial("tcp", connConfig.RemoteAddr)
if err != nil {
log.Errorf("Failed to establish new connections to remote: %v", err)
// TODO increase the interval if failed multiple times
time.Sleep(time.Second * 3)
goto makeconn
}
sk, err := transportConn.Handshake(remoteConn, authInfo)
if err != nil {
log.Errorf("Failed to prepare connection to remote: %v", err)
transportConn.Close()
// In Cloak v2.11.0, we've updated uTLS version and subsequently increased the first packet size for chrome above 1500
// https://github.com/cbeuw/Cloak/pull/306#issuecomment-2862728738. As a backwards compatibility feature, if we fail
// to connect using chrome signature, retry with firefox which has a smaller packet size.
if transportConfig.mode == "direct" && transportConfig.browser == chrome {
transportConfig.browser = firefox
log.Warnf("failed to connect with chrome signature, falling back to retry with firefox")
}
time.Sleep(time.Second * 3)
goto makeconn
}
// sessionKey given by each connection should be identical
_sessionKey.Store(sk)
connsCh <- transportConn
wg.Done()
}()
}
wg.Wait()
log.Debug("All underlying connections established")
sessionKey := _sessionKey.Load().([32]byte)
obfuscator, err := mux.MakeObfuscator(authInfo.EncryptionMethod, sessionKey)
if err != nil {
log.Fatal(err)
}
seshConfig := mux.SessionConfig{
Singleplex: connConfig.Singleplex,
Obfuscator: obfuscator,
Valve: nil,
Unordered: authInfo.Unordered,
MsgOnWireSizeLimit: appDataMaxLength,
}
sesh := mux.MakeSession(authInfo.SessionId, seshConfig)
for i := 0; i < connConfig.NumConn; i++ {
conn := <-connsCh
sesh.AddConnection(conn)
}
log.Infof("Session %v established", authInfo.SessionId)
return sesh
}

View File

@ -1,153 +0,0 @@
package client
import (
"io"
"net"
"sync"
"time"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) {
var sesh *mux.Session
localConn, err := bindFunc()
if err != nil {
log.Fatal(err)
}
streams := make(map[string]*mux.Stream)
var streamsMutex sync.Mutex
data := make([]byte, 8192)
for {
i, addr, err := localConn.ReadFrom(data)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
continue
}
if !singleplex && (sesh == nil || sesh.IsClosed()) {
sesh = newSeshFunc()
}
streamsMutex.Lock()
stream, ok := streams[addr.String()]
if !ok {
if singleplex {
sesh = newSeshFunc()
}
stream, err = sesh.OpenStream()
if err != nil {
if singleplex {
sesh.Close()
}
log.Errorf("Failed to open stream: %v", err)
streamsMutex.Unlock()
continue
}
streams[addr.String()] = stream
streamsMutex.Unlock()
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
proxyAddr := addr
go func(stream *mux.Stream, localConn *net.UDPConn) {
buf := make([]byte, 8192)
for {
n, err := stream.Read(buf)
if err != nil {
log.Tracef("copying stream to proxy client: %v", err)
break
}
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
_, err = localConn.WriteTo(buf[:n], proxyAddr)
if err != nil {
log.Tracef("copying stream to proxy client: %v", err)
break
}
}
streamsMutex.Lock()
delete(streams, addr.String())
streamsMutex.Unlock()
stream.Close()
return
}(stream, localConn)
} else {
streamsMutex.Unlock()
}
_, err = stream.Write(data[:i])
if err != nil {
log.Tracef("copying proxy client to stream: %v", err)
streamsMutex.Lock()
delete(streams, addr.String())
streamsMutex.Unlock()
stream.Close()
continue
}
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
}
}
func RouteTCP(listener net.Listener, streamTimeout time.Duration, singleplex bool, newSeshFunc func() *mux.Session) {
var sesh *mux.Session
for {
localConn, err := listener.Accept()
if err != nil {
log.Fatal(err)
continue
}
if !singleplex && (sesh == nil || sesh.IsClosed()) {
sesh = newSeshFunc()
}
go func(sesh *mux.Session, localConn net.Conn, timeout time.Duration) {
if singleplex {
sesh = newSeshFunc()
}
data := make([]byte, 10240)
_ = localConn.SetReadDeadline(time.Now().Add(streamTimeout))
i, err := io.ReadAtLeast(localConn, data, 1)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
return
}
var zeroTime time.Time
_ = localConn.SetReadDeadline(zeroTime)
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
localConn.Close()
if singleplex {
sesh.Close()
}
return
}
_, err = stream.Write(data[:i])
if err != nil {
log.Errorf("Failed to write to stream: %v", err)
localConn.Close()
stream.Close()
return
}
go func() {
if _, err := common.Copy(localConn, stream); err != nil {
log.Tracef("copying stream to proxy client: %v", err)
}
}()
if _, err = common.Copy(stream, localConn); err != nil {
log.Tracef("copying proxy client to stream: %v", err)
}
}(sesh, localConn, streamTimeout)
}
}

View File

@ -2,88 +2,67 @@ package client
import (
"crypto"
"encoding/base64"
"encoding/json"
"fmt"
"errors"
"io/ioutil"
"net"
"strings"
"sync"
"time"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/ecdh"
mux "github.com/cbeuw/Cloak/internal/multiplex"
ecdh "github.com/cbeuw/go-ecdh"
)
// RawConfig represents the fields in the config json file
// nullable means if it's empty, a default value will be chosen in ProcessRawConfig
// jsonOptional means if the json's empty, its value will be set from environment variables or commandline args
// but it mustn't be empty when ProcessRawConfig is called
type RawConfig struct {
type rawConfig struct {
ServerName string
ProxyMethod string
EncryptionMethod string
UID []byte
PublicKey []byte
UID string
PublicKey string
TicketTimeHint int
MaskBrowser string
NumConn int
LocalHost string // jsonOptional
LocalPort string // jsonOptional
RemoteHost string // jsonOptional
RemotePort string // jsonOptional
AlternativeNames []string // jsonOptional
// defaults set in ProcessRawConfig
UDP bool // nullable
BrowserSig string // nullable
Transport string // nullable
CDNOriginHost string // nullable
CDNWsUrlPath string // nullable
StreamTimeout int // nullable
KeepAlive int // nullable
}
type RemoteConnConfig struct {
Singleplex bool
NumConn int
KeepAlive time.Duration
RemoteAddr string
Transport TransportConfig
}
// State stores global variables
type State struct {
SS_LOCAL_HOST string
SS_LOCAL_PORT string
SS_REMOTE_HOST string
SS_REMOTE_PORT string
type LocalConnConfig struct {
LocalAddr string
Timeout time.Duration
MockDomainList []string
}
type AuthInfo struct {
Now func() time.Time
sessionID uint32
UID []byte
SessionId uint32
ProxyMethod string
EncryptionMethod byte
Unordered bool
ServerPubKey crypto.PublicKey
MockDomain string
WorldState common.WorldState
staticPub crypto.PublicKey
keyPairsM sync.RWMutex
keyPairs map[int64]*keyPair
TicketTimeHint int
ServerName string
MaskBrowser string
NumConn int
}
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, sessionID uint32) *State {
ret := &State{
SS_LOCAL_HOST: localHost,
SS_LOCAL_PORT: localPort,
SS_REMOTE_HOST: remoteHost,
SS_REMOTE_PORT: remotePort,
Now: nowFunc,
sessionID: sessionID,
}
ret.keyPairs = make(map[int64]*keyPair)
return ret
}
// semi-colon separated value. This is for Android plugin options
func ssvToJson(ssv string) (ret []byte) {
elem := func(val string, lst []string) bool {
for _, v := range lst {
if val == v {
return true
}
}
return false
}
unescape := func(s string) string {
r := strings.Replace(s, `\\`, `\`, -1)
r = strings.Replace(r, `\=`, `=`, -1)
r = strings.Replace(r, `\;`, `;`, -1)
return r
}
unquoted := []string{"NumConn", "StreamTimeout", "KeepAlive", "UDP"}
lines := strings.Split(unescape(ssv), ";")
ret = []byte("{")
for _, ln := range lines {
@ -91,29 +70,11 @@ func ssvToJson(ssv string) (ret []byte) {
break
}
sp := strings.SplitN(ln, "=", 2)
if len(sp) < 2 {
log.Errorf("Malformed config option: %v", ln)
continue
}
key := sp[0]
value := sp[1]
if strings.HasPrefix(key, "AlternativeNames") {
switch strings.Contains(value, ",") {
case true:
domains := strings.Split(value, ",")
for index, domain := range domains {
domains[index] = `"` + domain + `"`
}
value = strings.Join(domains, ",")
ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...)
case false:
ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...)
}
continue
}
// JSON doesn't like quotation marks around int and bool
// This is extremely ugly but it's still better than writing a tokeniser
if elem(key, unquoted) {
// JSON doesn't like quotation marks around int
// Yes this is extremely ugly but it's still better than writing a tokeniser
if key == "TicketTimeHint" || key == "NumConn" {
ret = append(ret, []byte(`"`+key+`":`+value+`,`)...)
} else {
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)
@ -124,156 +85,53 @@ func ssvToJson(ssv string) (ret []byte) {
return ret
}
func ParseConfig(conf string) (raw *RawConfig, err error) {
// ParseConfig parses the config (either a path to json or Android config) into a State variable
func (sta *State) ParseConfig(conf string) (err error) {
var content []byte
// Checking if it's a path to json or a ssv string
if strings.Contains(conf, ";") && strings.Contains(conf, "=") {
content = ssvToJson(conf)
} else {
content, err = ioutil.ReadFile(conf)
if err != nil {
return
return err
}
}
raw = new(RawConfig)
err = json.Unmarshal(content, &raw)
var preParse rawConfig
err = json.Unmarshal(content, &preParse)
if err != nil {
return
return err
}
return
sta.ServerName = preParse.ServerName
sta.TicketTimeHint = preParse.TicketTimeHint
sta.MaskBrowser = preParse.MaskBrowser
sta.NumConn = preParse.NumConn
uid, err := base64.StdEncoding.DecodeString(preParse.UID)
if err != nil {
return errors.New("Failed to parse UID: " + err.Error())
}
sta.UID = uid
func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local LocalConnConfig, remote RemoteConnConfig, auth AuthInfo, err error) {
nullErr := func(field string) (local LocalConnConfig, remote RemoteConnConfig, auth AuthInfo, err error) {
err = fmt.Errorf("%v cannot be empty", field)
return
pubBytes, err := base64.StdEncoding.DecodeString(preParse.PublicKey)
if err != nil {
return errors.New("Failed to parse Public key: " + err.Error())
}
auth.UID = raw.UID
auth.Unordered = raw.UDP
if raw.ServerName == "" {
return nullErr("ServerName")
}
auth.MockDomain = raw.ServerName
var filteredAlternativeNames []string
for _, alternativeName := range raw.AlternativeNames {
if len(alternativeName) > 0 {
filteredAlternativeNames = append(filteredAlternativeNames, alternativeName)
}
}
raw.AlternativeNames = filteredAlternativeNames
local.MockDomainList = raw.AlternativeNames
local.MockDomainList = append(local.MockDomainList, auth.MockDomain)
if raw.ProxyMethod == "" {
return nullErr("ServerName")
}
auth.ProxyMethod = raw.ProxyMethod
if len(raw.UID) == 0 {
return nullErr("UID")
}
// static public key
if len(raw.PublicKey) == 0 {
return nullErr("PublicKey")
}
pub, ok := ecdh.Unmarshal(raw.PublicKey)
ec := ecdh.NewCurve25519ECDH()
pub, ok := ec.Unmarshal(pubBytes)
if !ok {
err = fmt.Errorf("failed to unmarshal Public key")
return
return errors.New("Failed to unmarshal Public key")
}
auth.ServerPubKey = pub
auth.WorldState = worldState
// Encryption method
switch strings.ToLower(raw.EncryptionMethod) {
case "plain":
auth.EncryptionMethod = mux.EncryptionMethodPlain
case "aes-gcm", "aes-256-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAES256GCM
case "aes-128-gcm":
auth.EncryptionMethod = mux.EncryptionMethodAES128GCM
case "chacha20-poly1305":
auth.EncryptionMethod = mux.EncryptionMethodChaha20Poly1305
default:
err = fmt.Errorf("unknown encryption method %v", raw.EncryptionMethod)
return
sta.staticPub = pub
return nil
}
if raw.RemoteHost == "" {
return nullErr("RemoteHost")
}
if raw.RemotePort == "" {
return nullErr("RemotePort")
}
remote.RemoteAddr = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
if raw.NumConn <= 0 {
remote.NumConn = 1
remote.Singleplex = true
} else {
remote.NumConn = raw.NumConn
remote.Singleplex = false
func (sta *State) getKeyPair(tthInterval int64) *keyPair {
sta.keyPairsM.Lock()
defer sta.keyPairsM.Unlock()
return sta.keyPairs[tthInterval]
}
// Transport and (if TLS mode), browser
switch strings.ToLower(raw.Transport) {
case "cdn":
var cdnDomainPort string
if raw.CDNOriginHost == "" {
cdnDomainPort = net.JoinHostPort(raw.RemoteHost, raw.RemotePort)
} else {
cdnDomainPort = net.JoinHostPort(raw.CDNOriginHost, raw.RemotePort)
}
if raw.CDNWsUrlPath == "" {
raw.CDNWsUrlPath = "/"
}
remote.Transport = TransportConfig{
mode: "cdn",
wsUrl: "ws://" + cdnDomainPort + raw.CDNWsUrlPath,
}
case "direct":
fallthrough
default:
var browser browser
switch strings.ToLower(raw.BrowserSig) {
case "firefox":
browser = firefox
case "safari":
browser = safari
case "chrome":
fallthrough
default:
browser = chrome
}
remote.Transport = TransportConfig{
mode: "direct",
browser: browser,
}
}
// KeepAlive
if raw.KeepAlive <= 0 {
remote.KeepAlive = -1
} else {
remote.KeepAlive = remote.KeepAlive * time.Second
}
if raw.LocalHost == "" {
return nullErr("LocalHost")
}
if raw.LocalPort == "" {
return nullErr("LocalPort")
}
local.LocalAddr = net.JoinHostPort(raw.LocalHost, raw.LocalPort)
// stream no write timeout
if raw.StreamTimeout == 0 {
local.Timeout = 300 * time.Second
} else {
local.Timeout = time.Duration(raw.StreamTimeout) * time.Second
}
return
func (sta *State) putKeyPair(tthInterval int64, kp *keyPair) {
sta.keyPairsM.Lock()
sta.keyPairs[tthInterval] = kp
sta.keyPairsM.Unlock()
}

View File

@ -1,37 +0,0 @@
package client
import (
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseConfig(t *testing.T) {
ssv := "UID=iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=;PublicKey=IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=;" +
"ServerName=www.bing.com;NumConn=4;MaskBrowser=chrome;ProxyMethod=shadowsocks;EncryptionMethod=plain"
json := ssvToJson(ssv)
expected := []byte(`{"UID":"iGAO85zysIyR4c09CyZSLdNhtP/ckcYu7nIPI082AHA=","PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=","ServerName":"www.bing.com","NumConn":4,"MaskBrowser":"chrome","ProxyMethod":"shadowsocks","EncryptionMethod":"plain"}`)
t.Run("byte equality", func(t *testing.T) {
assert.Equal(t, expected, json)
})
t.Run("struct equality", func(t *testing.T) {
tmpConfig, _ := ioutil.TempFile("", "ck_client_config")
_, _ = tmpConfig.Write(expected)
parsedFromSSV, err := ParseConfig(ssv)
assert.NoError(t, err)
parsedFromJson, err := ParseConfig(tmpConfig.Name())
assert.NoError(t, err)
assert.Equal(t, parsedFromJson, parsedFromSSV)
})
t.Run("empty file", func(t *testing.T) {
tmpConfig, _ := ioutil.TempFile("", "ck_client_config")
_, err := ParseConfig(tmpConfig.Name())
assert.Error(t, err)
})
}

View File

@ -1,33 +0,0 @@
package client
import (
"net"
)
type Transport interface {
Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error)
net.Conn
}
type TransportConfig struct {
mode string
wsUrl string
browser browser
}
func (t TransportConfig) CreateTransport() Transport {
switch t.mode {
case "cdn":
return &WSOverTLS{
wsUrl: t.wsUrl,
}
case "direct":
return &DirectTLS{
browser: t.browser,
}
default:
return nil
}
}

View File

@ -1,84 +0,0 @@
package client
import (
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"github.com/cbeuw/Cloak/internal/common"
"github.com/gorilla/websocket"
utls "github.com/refraction-networking/utls"
)
type WSOverTLS struct {
*common.WebSocketConn
wsUrl string
}
func (ws *WSOverTLS) Handshake(rawConn net.Conn, authInfo AuthInfo) (sessionKey [32]byte, err error) {
utlsConfig := &utls.Config{
ServerName: authInfo.MockDomain,
InsecureSkipVerify: true,
}
uconn := utls.UClient(rawConn, utlsConfig, utls.HelloChrome_Auto)
err = uconn.BuildHandshakeState()
if err != nil {
return
}
for i, extension := range uconn.Extensions {
_, ok := extension.(*utls.ALPNExtension)
if ok {
uconn.Extensions = append(uconn.Extensions[:i], uconn.Extensions[i+1:]...)
break
}
}
err = uconn.Handshake()
if err != nil {
return
}
u, err := url.Parse(ws.wsUrl)
if err != nil {
return sessionKey, fmt.Errorf("failed to parse ws url: %v", err)
}
payload, sharedSecret := makeAuthenticationPayload(authInfo)
header := http.Header{}
header.Add("hidden", base64.StdEncoding.EncodeToString(append(payload.randPubKey[:], payload.ciphertextWithTag[:]...)))
c, _, err := websocket.NewClient(uconn, u, header, 16480, 16480)
if err != nil {
return sessionKey, fmt.Errorf("failed to handshake: %v", err)
}
ws.WebSocketConn = &common.WebSocketConn{Conn: c}
buf := make([]byte, 128)
n, err := ws.Read(buf)
if err != nil {
return sessionKey, fmt.Errorf("failed to read reply: %v", err)
}
if n != 60 {
return sessionKey, errors.New("reply must be 60 bytes")
}
reply := buf[:60]
sessionKeySlice, err := common.AESGCMDecrypt(reply[:12], sharedSecret[:], reply[12:])
if err != nil {
return
}
copy(sessionKey[:], sessionKeySlice)
return
}
func (ws *WSOverTLS) Close() error {
if ws.WebSocketConn != nil {
return ws.WebSocketConn.Close()
}
return nil
}

View File

@ -1,79 +0,0 @@
/*
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
Forked from https://golang.org/src/io/io.go
*/
package common
import (
"io"
"net"
)
func Copy(dst net.Conn, src net.Conn) (written int64, err error) {
defer func() { src.Close(); dst.Close() }()
// If the reader has a WriteTo method, use it to do the copy.
// Avoids an allocation and a copy.
if wt, ok := src.(io.WriterTo); ok {
return wt.WriteTo(dst)
}
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
if rt, ok := dst.(io.ReaderFrom); ok {
return rt.ReadFrom(src)
}
size := 32 * 1024
buf := make([]byte, size)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}

View File

@ -1,97 +0,0 @@
package common
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"io"
"math/big"
"time"
log "github.com/sirupsen/logrus"
)
func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != aesgcm.NonceSize() {
// check here so it doesn't panic
return nil, errors.New("incorrect nonce size")
}
return aesgcm.Seal(nil, nonce, plaintext, nil), nil
}
func AESGCMDecrypt(nonce []byte, key []byte, ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != aesgcm.NonceSize() {
// check here so it doesn't panic
return nil, errors.New("incorrect nonce size")
}
plain, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plain, nil
}
func CryptoRandRead(buf []byte) {
RandRead(rand.Reader, buf)
}
func backoff(f func() error) {
err := f()
if err == nil {
return
}
waitDur := [10]time.Duration{5 * time.Millisecond, 10 * time.Millisecond, 30 * time.Millisecond, 50 * time.Millisecond,
100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second,
3 * time.Second, 5 * time.Second}
for i := 0; i < 10; i++ {
log.Errorf("Failed to get random: %v. Retrying...", err)
err = f()
if err == nil {
return
}
time.Sleep(waitDur[i])
}
log.Fatal("Cannot get random after 10 retries")
}
func RandRead(randSource io.Reader, buf []byte) {
backoff(func() error {
_, err := randSource.Read(buf)
return err
})
}
func RandItem[T any](list []T) T {
return list[RandInt(len(list))]
}
func RandInt(n int) int {
s := new(int)
backoff(func() error {
size, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
if err != nil {
return err
}
*s = int(size.Int64())
return nil
})
return *s
}

View File

@ -1,95 +0,0 @@
package common
import (
"bytes"
"encoding/hex"
"errors"
"io"
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
const gcmTagSize = 16
func TestAESGCM(t *testing.T) {
// test vectors from https://luca-giuzzi.unibs.it/corsi/Support/papers-cryptography/gcm-spec.pdf
t.Run("correct 128", func(t *testing.T) {
key, _ := hex.DecodeString("00000000000000000000000000000000")
plaintext, _ := hex.DecodeString("")
nonce, _ := hex.DecodeString("000000000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
encryptedWithTag, err := AESGCMEncrypt(nonce, key, plaintext)
assert.NoError(t, err)
assert.Equal(t, ciphertext, encryptedWithTag[:len(plaintext)])
assert.Equal(t, tag, encryptedWithTag[len(plaintext):len(plaintext)+gcmTagSize])
decrypted, err := AESGCMDecrypt(nonce, key, encryptedWithTag)
assert.NoError(t, err)
// slight inconvenience here that assert.Equal does not consider a nil slice and an empty slice to be
// equal. decrypted should be []byte(nil) but plaintext is []byte{}
assert.True(t, bytes.Equal(plaintext, decrypted))
})
t.Run("bad key size", func(t *testing.T) {
key, _ := hex.DecodeString("0000000000000000000000000000")
plaintext, _ := hex.DecodeString("")
nonce, _ := hex.DecodeString("000000000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
_, err := AESGCMEncrypt(nonce, key, plaintext)
assert.Error(t, err)
_, err = AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
assert.Error(t, err)
})
t.Run("bad nonce size", func(t *testing.T) {
key, _ := hex.DecodeString("00000000000000000000000000000000")
plaintext, _ := hex.DecodeString("")
nonce, _ := hex.DecodeString("00000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("58e2fccefa7e3061367f1d57a4e7455a")
_, err := AESGCMEncrypt(nonce, key, plaintext)
assert.Error(t, err)
_, err = AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
assert.Error(t, err)
})
t.Run("bad tag", func(t *testing.T) {
key, _ := hex.DecodeString("00000000000000000000000000000000")
nonce, _ := hex.DecodeString("00000000000000000000")
ciphertext, _ := hex.DecodeString("")
tag, _ := hex.DecodeString("fffffccefa7e3061367f1d57a4e745ff")
_, err := AESGCMDecrypt(nonce, key, append(ciphertext, tag...))
assert.Error(t, err)
})
}
type failingReader struct {
fails int
reader io.Reader
}
func (f *failingReader) Read(p []byte) (n int, err error) {
if f.fails > 0 {
f.fails -= 1
return 0, errors.New("no data for you yet")
} else {
return f.reader.Read(p)
}
}
func TestRandRead(t *testing.T) {
failer := &failingReader{
fails: 3,
reader: rand.New(rand.NewSource(0)),
}
readBuf := make([]byte, 10)
RandRead(failer, readBuf)
assert.NotEqual(t, [10]byte{}, readBuf)
}

View File

@ -1,7 +0,0 @@
package common
import "net"
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}

View File

@ -1,112 +0,0 @@
package common
import (
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
)
const (
VersionTLS11 = 0x0301
VersionTLS13 = 0x0303
recordLayerLength = 5
Handshake = 22
ApplicationData = 23
initialWriteBufSize = 14336
)
func AddRecordLayer(input []byte, typ byte, ver uint16) []byte {
msgLen := len(input)
retLen := msgLen + recordLayerLength
var ret []byte
ret = make([]byte, retLen)
copy(ret[recordLayerLength:], input)
ret[0] = typ
ret[1] = byte(ver >> 8)
ret[2] = byte(ver)
ret[3] = byte(msgLen >> 8)
ret[4] = byte(msgLen)
return ret
}
type TLSConn struct {
net.Conn
writeBufPool sync.Pool
}
func NewTLSConn(conn net.Conn) *TLSConn {
return &TLSConn{
Conn: conn,
writeBufPool: sync.Pool{New: func() interface{} {
b := make([]byte, 0, initialWriteBufSize)
b = append(b, ApplicationData, byte(VersionTLS13>>8), byte(VersionTLS13&0xFF))
return &b
}},
}
}
func (tls *TLSConn) LocalAddr() net.Addr {
return tls.Conn.LocalAddr()
}
func (tls *TLSConn) RemoteAddr() net.Addr {
return tls.Conn.RemoteAddr()
}
func (tls *TLSConn) SetDeadline(t time.Time) error {
return tls.Conn.SetDeadline(t)
}
func (tls *TLSConn) SetReadDeadline(t time.Time) error {
return tls.Conn.SetReadDeadline(t)
}
func (tls *TLSConn) SetWriteDeadline(t time.Time) error {
return tls.Conn.SetWriteDeadline(t)
}
func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
// TCP is a stream. Multiple TLS messages can arrive at the same time,
// a single message can also be segmented due to MTU of the IP layer.
// This function guareentees a single TLS message to be read and everything
// else is left in the buffer.
if len(buffer) < recordLayerLength {
return 0, io.ErrShortBuffer
}
_, err = io.ReadFull(tls.Conn, buffer[:recordLayerLength])
if err != nil {
return
}
dataLength := int(binary.BigEndian.Uint16(buffer[3:5]))
if dataLength > len(buffer) {
err = io.ErrShortBuffer
return
}
// we overwrite the record layer here
return io.ReadFull(tls.Conn, buffer[:dataLength])
}
func (tls *TLSConn) Write(in []byte) (n int, err error) {
msgLen := len(in)
if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2
return 0, errors.New("message is too long")
}
writeBuf := tls.writeBufPool.Get().(*[]byte)
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
*writeBuf = append(*writeBuf, in...)
n, err = tls.Conn.Write(*writeBuf)
*writeBuf = (*writeBuf)[:3]
tls.writeBufPool.Put(writeBuf)
return n - recordLayerLength, err
}
func (tls *TLSConn) Close() error {
return tls.Conn.Close()
}

View File

@ -1,40 +0,0 @@
package common
import (
"net"
"testing"
)
func BenchmarkTLSConn_Write(b *testing.B) {
const bufSize = 16 * 1024
addrCh := make(chan string, 1)
go func() {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatal(err)
}
addrCh <- listener.Addr().String()
conn, err := listener.Accept()
if err != nil {
b.Fatal(err)
}
readBuf := make([]byte, bufSize*2)
for {
_, err = conn.Read(readBuf)
if err != nil {
return
}
}
}()
data := make([]byte, bufSize)
discardConn, _ := net.Dial("tcp", <-addrCh)
tlsConn := NewTLSConn(discardConn)
defer tlsConn.Close()
b.SetBytes(bufSize)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
tlsConn.Write(data)
}
})
}

View File

@ -1,77 +0,0 @@
package common
import (
"errors"
"io"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WebSocketConn implements io.ReadWriteCloser
// it makes websocket.Conn binary-oriented
type WebSocketConn struct {
*websocket.Conn
writeM sync.Mutex
}
func (ws *WebSocketConn) Write(data []byte) (int, error) {
ws.writeM.Lock()
err := ws.WriteMessage(websocket.BinaryMessage, data)
ws.writeM.Unlock()
if err != nil {
return 0, err
} else {
return len(data), nil
}
}
func (ws *WebSocketConn) Read(buf []byte) (n int, err error) {
t, r, err := ws.NextReader()
if err != nil {
return 0, err
}
if t != websocket.BinaryMessage {
return 0, nil
}
// Read until io.EOL for one full message
for {
var read int
read, err = r.Read(buf[n:])
if err != nil {
if err == io.EOF {
err = nil
break
} else {
break
}
} else {
// There may be data available to read but n == len(buf)-1, read==0 because buffer is full
if read == 0 {
err = errors.New("nothing more is read. message may be larger than buffer")
break
}
}
n += read
}
return
}
func (ws *WebSocketConn) Close() error {
ws.writeM.Lock()
defer ws.writeM.Unlock()
return ws.Conn.Close()
}
func (ws *WebSocketConn) SetDeadline(t time.Time) error {
err := ws.SetReadDeadline(t)
if err != nil {
return err
}
err = ws.SetWriteDeadline(t)
if err != nil {
return err
}
return nil
}

View File

@ -1,24 +0,0 @@
package common
import (
"crypto/rand"
"io"
"time"
)
var RealWorldState = WorldState{
Rand: rand.Reader,
Now: time.Now,
}
type WorldState struct {
Rand io.Reader
Now func() time.Time
}
func WorldOfTime(t time.Time) WorldState {
return WorldState{
Rand: rand.Reader,
Now: func() time.Time { return t },
}
}

View File

@ -1,78 +0,0 @@
// This code is forked from https://github.com/wsddn/go-ecdh/blob/master/curve25519.go
/*
Copyright (c) 2014, tang0th
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of tang0th nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package ecdh
import (
"crypto"
"io"
"golang.org/x/crypto/curve25519"
)
func GenerateKey(rand io.Reader) (crypto.PrivateKey, crypto.PublicKey, error) {
var pub, priv [32]byte
var err error
_, err = io.ReadFull(rand, priv[:])
if err != nil {
return nil, nil, err
}
priv[0] &= 248
priv[31] &= 127
priv[31] |= 64
curve25519.ScalarBaseMult(&pub, &priv)
return &priv, &pub, nil
}
func Marshal(p crypto.PublicKey) []byte {
pub := p.(*[32]byte)
return pub[:]
}
func Unmarshal(data []byte) (crypto.PublicKey, bool) {
var pub [32]byte
if len(data) != 32 {
return nil, false
}
copy(pub[:], data)
return &pub, true
}
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) {
var priv, pub *[32]byte
priv = privKey.(*[32]byte)
pub = pubKey.(*[32]byte)
return curve25519.X25519(priv[:], pub[:])
}

View File

@ -1,105 +0,0 @@
// This code is forked from https://github.com/wsddn/go-ecdh/blob/master/curve25519.go
/*
Copyright (c) 2014, tang0th
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of tang0th nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package ecdh
import (
"bytes"
"crypto"
"crypto/rand"
"io"
"testing"
)
func TestCurve25519(t *testing.T) {
testECDH(t)
}
func TestErrors(t *testing.T) {
reader, writer := io.Pipe()
_ = writer.Close()
_, _, err := GenerateKey(reader)
if err == nil {
t.Error("GenerateKey should return error")
}
_, ok := Unmarshal([]byte{1})
if ok {
t.Error("Unmarshal should return false")
}
}
func BenchmarkCurve25519(b *testing.B) {
for i := 0; i < b.N; i++ {
testECDH(b)
}
}
func testECDH(t testing.TB) {
var privKey1, privKey2 crypto.PrivateKey
var pubKey1, pubKey2 crypto.PublicKey
var pubKey1Buf, pubKey2Buf []byte
var err error
var ok bool
var secret1, secret2 []byte
privKey1, pubKey1, err = GenerateKey(rand.Reader)
if err != nil {
t.Error(err)
}
privKey2, pubKey2, err = GenerateKey(rand.Reader)
if err != nil {
t.Error(err)
}
pubKey1Buf = Marshal(pubKey1)
pubKey2Buf = Marshal(pubKey2)
pubKey1, ok = Unmarshal(pubKey1Buf)
if !ok {
t.Fatalf("Unmarshal does not work")
}
pubKey2, ok = Unmarshal(pubKey2Buf)
if !ok {
t.Fatalf("Unmarshal does not work")
}
secret1, err = GenerateSharedSecret(privKey1, pubKey2)
if err != nil {
t.Error(err)
}
secret2, err = GenerateSharedSecret(privKey2, pubKey1)
if err != nil {
t.Error(err)
}
if !bytes.Equal(secret1, secret2) {
t.Fatalf("The two shared keys: %d, %d do not match", secret1, secret2)
}
}

View File

@ -1,119 +0,0 @@
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
package multiplex
import (
"bytes"
"io"
"sync"
"time"
)
// datagramBufferedPipe is the same as streamBufferedPipe with the exception that it's message-oriented,
// instead of byte-oriented. The integrity of datagrams written into this buffer is preserved.
// it won't get chopped up into individual bytes
type datagramBufferedPipe struct {
pLens []int
buf *bytes.Buffer
closed bool
rwCond *sync.Cond
wtTimeout time.Duration
rDeadline time.Time
timeoutTimer *time.Timer
}
func NewDatagramBufferedPipe() *datagramBufferedPipe {
d := &datagramBufferedPipe{
rwCond: sync.NewCond(&sync.Mutex{}),
buf: new(bytes.Buffer),
}
return d
}
func (d *datagramBufferedPipe) Read(target []byte) (int, error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed && len(d.pLens) == 0 {
return 0, io.EOF
}
hasRDeadline := !d.rDeadline.IsZero()
if hasRDeadline {
if time.Until(d.rDeadline) <= 0 {
return 0, ErrTimeout
}
}
if len(d.pLens) > 0 {
break
}
if hasRDeadline {
d.broadcastAfter(time.Until(d.rDeadline))
}
d.rwCond.Wait()
}
dataLen := d.pLens[0]
if len(target) < dataLen {
return 0, io.ErrShortBuffer
}
d.pLens = d.pLens[1:]
d.buf.Read(target[:dataLen])
// err will always be nil because we have already verified that buf.Len() != 0
d.rwCond.Broadcast()
return dataLen, nil
}
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
for {
if d.closed {
return true, io.ErrClosedPipe
}
if d.buf.Len() <= recvBufferSizeLimit {
// if d.buf gets too large, write() will panic. We don't want this to happen
break
}
d.rwCond.Wait()
}
if f.Closing != closingNothing {
d.closed = true
d.rwCond.Broadcast()
return true, nil
}
dataLen := len(f.Payload)
d.pLens = append(d.pLens, dataLen)
d.buf.Write(f.Payload)
// err will always be nil
d.rwCond.Broadcast()
return false, nil
}
func (d *datagramBufferedPipe) Close() error {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.closed = true
d.rwCond.Broadcast()
return nil
}
func (d *datagramBufferedPipe) SetReadDeadline(t time.Time) {
d.rwCond.L.Lock()
defer d.rwCond.L.Unlock()
d.rDeadline = t
d.rwCond.Broadcast()
}
func (d *datagramBufferedPipe) broadcastAfter(t time.Duration) {
if d.timeoutTimer != nil {
d.timeoutTimer.Stop()
}
d.timeoutTimer = time.AfterFunc(t, d.rwCond.Broadcast)
}

View File

@ -1,62 +0,0 @@
package multiplex
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDatagramBuffer_RW(t *testing.T) {
b := []byte{0x01, 0x02, 0x03}
t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
_, err := pipe.Write(&Frame{Payload: b})
assert.NoError(t, err)
})
t.Run("simple read", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
_, _ = pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
assert.NoError(t, err)
assert.Equal(t, len(b), n)
assert.Equal(t, b, b2)
assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading")
})
t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
assert.NoError(t, err)
assert.True(t, toBeClosed, "should be to be closed")
assert.True(t, pipe.closed, "pipe should be closed")
})
}
func TestDatagramBuffer_BlockingRead(t *testing.T) {
pipe := NewDatagramBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(readBlockTime)
pipe.Write(&Frame{Payload: b})
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
assert.NoError(t, err)
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
assert.Equal(t, b, b2)
}
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
pipe := NewDatagramBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
assert.NoError(t, err)
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
assert.Equal(t, b, b2)
}

View File

@ -1,14 +1,10 @@
package multiplex
const (
closingNothing = iota
closingStream
closingSession
)
import ()
type Frame struct {
StreamID uint32
Seq uint64
Closing uint8
Seq uint32
Closing uint32
Payload []byte
}

View File

@ -0,0 +1,129 @@
package multiplex
import (
"container/heap"
"log"
)
// The data is multiplexed through several TCP connections, therefore the
// order of arrival is not guaranteed. A stream's first packet may be sent through
// connection0 and its second packet may be sent through connection1. Although both
// packets are transmitted reliably (as TCP is reliable), packet1 may arrive to the
// remote side before packet0.
//
// However, shadowsocks' protocol does not provide sequence control. We must therefore
// make sure packets arrive in order.
//
// Cloak packets will have a 32-bit sequence number on them, so we know in which order
// they should be sent to shadowsocks. The code in this file provides buffering and sorting.
//
// Similar to TCP, the next seq number after 2^32-1 is 0. This is called wrap around.
//
// Note that in golang, integer overflow results in wrap around
//
// Stream.nextRecvSeq is the expected sequence number of the next packet
// Stream.rev counts the amount of time the sequence number gets wrapped
type frameNode struct {
seq uint32
trueSeq uint64
frame *Frame
}
type sorterHeap []*frameNode
func (sh sorterHeap) Less(i, j int) bool {
return sh[i].trueSeq < sh[j].trueSeq
}
func (sh sorterHeap) Len() int {
return len(sh)
}
func (sh sorterHeap) Swap(i, j int) {
sh[i], sh[j] = sh[j], sh[i]
}
func (sh *sorterHeap) Push(x interface{}) {
*sh = append(*sh, x.(*frameNode))
}
func (sh *sorterHeap) Pop() interface{} {
old := *sh
n := len(old)
x := old[n-1]
*sh = old[0 : n-1]
return x
}
func (s *Stream) writeNewFrame(f *Frame) {
s.newFrameCh <- f
}
// recvNewFrame is a forever running loop which receives frames unordered,
// cache and order them and send them into sortedBufCh
func (s *Stream) recvNewFrame() {
for {
var f *Frame
select {
case <-s.die:
return
case f = <-s.newFrameCh:
}
if f == nil {
log.Println("nil frame")
continue
}
// when there's no ooo packages in heap and we receive the next package in order
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
s.pushFrame(f)
continue
}
fs := &frameNode{
f.Seq,
0,
f,
}
if fs.seq < s.nextRecvSeq {
// For the ease of demonstration, assume seq is uint8, i.e. it wraps around after 255
// e.g. we are on rev=0 (wrap has not happened yet)
// and we get the order of recv as 253 254 0 1
// after 254, nextN should be 255, but 0 is received and 0 < 255
// now 0 should have a trueSeq of 256
if !s.wrapMode {
// wrapMode is true when the latest seq is wrapped but nextN is not
s.wrapMode = true
}
fs.trueSeq = uint64(1<<16*(s.rev+1)) + uint64(fs.seq) + 1
// +1 because wrapped 0 should have trueSeq of 256 instead of 255
// when this bit was run on 1, the trueSeq of 1 would become 256
} else {
fs.trueSeq = uint64(1<<16*s.rev) + uint64(fs.seq)
// when this bit was run on 255, the trueSeq of 255 would be 255
}
heap.Push(&s.sh, fs)
// Keep popping from the heap until empty or to the point that the wanted seq was not received
for len(s.sh) > 0 && s.sh[0].seq == s.nextRecvSeq {
frame := heap.Pop(&s.sh).(*frameNode).frame
s.pushFrame(frame)
}
}
}
func (s *Stream) pushFrame(f *Frame) {
if f.Closing == 1 {
s.sortedBufCh <- []byte{}
return
}
s.sortedBufCh <- f.Payload
s.nextRecvSeq += 1
if s.nextRecvSeq == 0 {
// when nextN is wrapped, wrapMode becomes false and rev+1
s.rev += 1
s.wrapMode = false
}
}

View File

@ -1,150 +0,0 @@
package multiplex
import (
"bytes"
"io"
"math/rand"
"net"
"sync"
"testing"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
)
func serveEcho(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
// TODO: pass the error back
return
}
go func(conn net.Conn) {
_, err := io.Copy(conn, conn)
if err != nil {
// TODO: pass the error back
return
}
}(conn)
}
}
type connPair struct {
clientConn net.Conn
serverConn net.Conn
}
func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
sessionKey := [32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
sessionId := 1
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
clientConfig := SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: false,
}
serverConfig := clientConfig
clientSession := MakeSession(uint32(sessionId), clientConfig)
serverSession := MakeSession(uint32(sessionId), serverConfig)
paris := make([]*connPair, numConn)
for i := 0; i < numConn; i++ {
c, s := connutil.AsyncPipe()
clientConn := common.NewTLSConn(c)
serverConn := common.NewTLSConn(s)
paris[i] = &connPair{
clientConn: clientConn,
serverConn: serverConn,
}
clientSession.AddConnection(clientConn)
serverSession.AddConnection(serverConn)
}
return clientSession, serverSession, paris
}
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
testData := make([]byte, msgLen)
rand.Read(testData)
// we cannot call t.Fatalf in concurrent contexts
n, err := conn.Write(testData)
if n != msgLen {
t.Errorf("written only %v, err %v", n, err)
return
}
recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf)
if err != nil {
t.Errorf("failed to read back: %v", err)
return
}
if !bytes.Equal(testData, recvBuf) {
t.Errorf("echoed data not correct")
return
}
}(conn)
}
wg.Wait()
}
func TestMultiplex(t *testing.T) {
const numStreams = 2000 // -race option limits the number of goroutines to 8192
const numConns = 4
const msgLen = 16384
clientSession, serverSession, _ := makeSessionPair(numConns)
go serveEcho(serverSession)
streams := make([]net.Conn, numStreams)
for i := 0; i < numStreams; i++ {
stream, err := clientSession.OpenStream()
assert.NoError(t, err)
streams[i] = stream
}
//test echo
runEchoTest(t, streams, msgLen)
assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong")
assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong")
// close one stream
closing, streams := streams[0], streams[1:]
err := closing.Close()
assert.NoError(t, err, "couldn't close a stream")
_, err = closing.Write([]byte{0})
assert.Equal(t, ErrBrokenStream, err)
_, err = closing.Read(make([]byte, 1))
assert.Equal(t, ErrBrokenStream, err)
}
func TestMux_StreamClosing(t *testing.T) {
clientSession, serverSession, _ := makeSessionPair(1)
go serveEcho(serverSession)
// read after closing stream
testData := make([]byte, 128)
recvBuf := make([]byte, 128)
toBeClosed, _ := clientSession.OpenStream()
_, err := toBeClosed.Write(testData) // should be echoed back
assert.NoError(t, err, "couldn't write to a stream")
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
assert.NoError(t, err, "can't read anything before stream closed")
_ = toBeClosed.Close()
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
assert.NoError(t, err, "can't read residual data on stream")
assert.Equal(t, testData, recvBuf, "incorrect data read back")
}

View File

@ -1,201 +1,76 @@
package multiplex
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/common"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/salsa20"
xxhash "github.com/OneOfOne/xxhash"
)
const frameHeaderLength = 14
const salsa20NonceSize = 8
type Obfser func(*Frame) ([]byte, error)
type Deobfser func([]byte) (*Frame, error)
// maxExtraLen equals the max length of padding + AEAD tag.
// It is 255 bytes because the extra len field in frame header is only one byte.
const maxExtraLen = 1<<8 - 1
// padFirstNFrames specifies the number of initial frames to pad,
// to avoid TLS-in-TLS detection
const padFirstNFrames = 5
const (
EncryptionMethodPlain = iota
EncryptionMethodAES256GCM
EncryptionMethodChaha20Poly1305
EncryptionMethodAES128GCM
)
// Obfuscator is responsible for serialisation, obfuscation, and optional encryption of data frames.
type Obfuscator struct {
payloadCipher cipher.AEAD
sessionKey [32]byte
// For each frame, the three parts of the header is xored with three keys.
// The keys are generated from the SID and the payload of the frame.
func genXorKeys(secret []byte, data []byte) (i uint32, ii uint32, iii uint32) {
h := xxhash.New32()
ret := make([]uint32, 3)
preHash := make([]byte, 16)
for j := 0; j < 3; j++ {
copy(preHash[0:10], secret[j*10:j*10+10])
copy(preHash[10:16], data[j*6:j*6+6])
h.Write(preHash)
ret[j] = h.Sum32()
}
return ret[0], ret[1], ret[2]
}
// obfuscate adds multiplexing headers, encrypt and add TLS header
func (o *Obfuscator) obfuscate(f *Frame, buf []byte, payloadOffsetInBuf int) (int, error) {
// The method here is to use the first payloadCipher.NonceSize() bytes of the serialised frame header
// as iv/nonce for the AEAD cipher to encrypt the frame payload. Then we use
// the authentication tag produced appended to the end of the ciphertext (of size payloadCipher.Overhead())
// as nonce for Salsa20 to encrypt the frame header. Both with sessionKey as keys.
//
// Several cryptographic guarantees we have made here: that payloadCipher, as an AEAD, is given a unique
// iv/nonce each time, relative to its key; that the frame header encryptor Salsa20 is given a unique
// nonce each time, relative to its key; and that the authenticity of frame header is checked.
//
// The payloadCipher is given a unique iv/nonce each time because it is derived from the frame header, which
// contains the monotonically increasing stream id (uint32) and frame sequence (uint64). There will be a nonce
// reuse after 2^64-1 frames sent (sent, not received because frames going different ways are sequenced
// independently) by a stream, or after 2^32-1 streams created in a single session. We consider these number
// to be large enough that they may never happen in reasonable time frames. Of course, different sessions
// will produce the same combination of stream id and frame sequence, but they will have different session keys.
//
//
// Because the frame header, before it being encrypted, is fed into the AEAD, it is also authenticated.
// (rfc5116 s.2.1 "The nonce is authenticated internally to the algorithm").
//
// In case the user chooses to not encrypt the frame payload, payloadCipher will be nil. In this scenario,
// we generate random bytes to be used as salsa20 nonce.
payloadLen := len(f.Payload)
if payloadLen == 0 {
return 0, errors.New("payload cannot be empty")
func MakeObfs(key []byte) Obfser {
obfs := func(f *Frame) ([]byte, error) {
if len(f.Payload) < 18 {
return nil, errors.New("Payload cannot be shorter than 18 bytes")
}
tagLen := 0
if o.payloadCipher != nil {
tagLen = o.payloadCipher.Overhead()
} else {
tagLen = salsa20NonceSize
obfsedHeader := make([]byte, 12)
// header: [StreamID 4 bytes][Seq 4 bytes][Closing 4 bytes]
i, ii, iii := genXorKeys(key, f.Payload[0:18])
binary.BigEndian.PutUint32(obfsedHeader[0:4], f.StreamID^i)
binary.BigEndian.PutUint32(obfsedHeader[4:8], f.Seq^ii)
binary.BigEndian.PutUint32(obfsedHeader[8:12], f.Closing^iii)
// Composing final obfsed message
// We don't use util.AddRecordLayer here to avoid unnecessary malloc
obfsed := make([]byte, 5+12+len(f.Payload))
obfsed[0] = 0x17
obfsed[1] = 0x03
obfsed[2] = 0x03
binary.BigEndian.PutUint16(obfsed[3:5], uint16(12+len(f.Payload)))
copy(obfsed[5:17], obfsedHeader)
copy(obfsed[17:], f.Payload)
// obfsed: [record layer 5 bytes][cipherheader 12 bytes][payload]
return obfsed, nil
}
// Pad to avoid size side channel leak
padLen := 0
if f.Seq < padFirstNFrames {
padLen = common.RandInt(maxExtraLen - tagLen + 1)
return obfs
}
usefulLen := frameHeaderLength + payloadLen + padLen + tagLen
if len(buf) < usefulLen {
return 0, errors.New("obfs buffer too small")
func MakeDeobfs(key []byte) Deobfser {
deobfs := func(in []byte) (*Frame, error) {
if len(in) < 30 {
return nil, errors.New("Input cannot be shorter than 30 bytes")
}
// we do as much in-place as possible to save allocation
payload := buf[frameHeaderLength : frameHeaderLength+payloadLen+padLen]
if payloadOffsetInBuf != frameHeaderLength {
// if payload is not at the correct location in buffer
copy(payload, f.Payload)
peeled := in[5:]
i, ii, iii := genXorKeys(key, peeled[12:30])
streamID := binary.BigEndian.Uint32(peeled[0:4]) ^ i
seq := binary.BigEndian.Uint32(peeled[4:8]) ^ ii
closing := binary.BigEndian.Uint32(peeled[8:12]) ^ iii
payload := make([]byte, len(peeled)-12)
copy(payload, peeled[12:])
ret := &Frame{
StreamID: streamID,
Seq: seq,
Closing: closing,
Payload: payload,
}
header := buf[:frameHeaderLength]
binary.BigEndian.PutUint32(header[0:4], f.StreamID)
binary.BigEndian.PutUint64(header[4:12], f.Seq)
header[12] = f.Closing
header[13] = byte(padLen + tagLen)
// Random bytes for padding and nonce
_, err := rand.Read(buf[frameHeaderLength+payloadLen : usefulLen])
if err != nil {
return 0, fmt.Errorf("failed to pad random: %w", err)
return ret, nil
}
if o.payloadCipher != nil {
o.payloadCipher.Seal(payload[:0], header[:o.payloadCipher.NonceSize()], payload, nil)
}
nonce := buf[usefulLen-salsa20NonceSize : usefulLen]
salsa20.XORKeyStream(header, header, nonce, &o.sessionKey)
return usefulLen, nil
}
// deobfuscate removes TLS header, decrypt and unmarshall frames
func (o *Obfuscator) deobfuscate(f *Frame, in []byte) error {
if len(in) < frameHeaderLength+salsa20NonceSize {
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), frameHeaderLength+salsa20NonceSize)
}
header := in[:frameHeaderLength]
pldWithOverHead := in[frameHeaderLength:] // payload + potential overhead
nonce := in[len(in)-salsa20NonceSize:]
salsa20.XORKeyStream(header, header, nonce, &o.sessionKey)
streamID := binary.BigEndian.Uint32(header[0:4])
seq := binary.BigEndian.Uint64(header[4:12])
closing := header[12]
extraLen := header[13]
usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
}
var outputPayload []byte
if o.payloadCipher == nil {
if extraLen == 0 {
outputPayload = pldWithOverHead
} else {
outputPayload = pldWithOverHead[:usefulPayloadLen]
}
} else {
_, err := o.payloadCipher.Open(pldWithOverHead[:0], header[:o.payloadCipher.NonceSize()], pldWithOverHead, nil)
if err != nil {
return err
}
outputPayload = pldWithOverHead[:usefulPayloadLen]
}
f.StreamID = streamID
f.Seq = seq
f.Closing = closing
f.Payload = outputPayload
return nil
}
func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (o Obfuscator, err error) {
o = Obfuscator{
sessionKey: sessionKey,
}
switch encryptionMethod {
case EncryptionMethodPlain:
o.payloadCipher = nil
case EncryptionMethodAES256GCM:
var c cipher.Block
c, err = aes.NewCipher(sessionKey[:])
if err != nil {
return
}
o.payloadCipher, err = cipher.NewGCM(c)
if err != nil {
return
}
case EncryptionMethodAES128GCM:
var c cipher.Block
c, err = aes.NewCipher(sessionKey[:16])
if err != nil {
return
}
o.payloadCipher, err = cipher.NewGCM(c)
if err != nil {
return
}
case EncryptionMethodChaha20Poly1305:
o.payloadCipher, err = chacha20poly1305.New(sessionKey[:])
if err != nil {
return
}
default:
return o, fmt.Errorf("unknown encryption method valued %v", encryptionMethod)
}
if o.payloadCipher != nil {
if o.payloadCipher.NonceSize() > frameHeaderLength {
return o, errors.New("payload AEAD's nonce size cannot be greater than size of frame header")
}
}
return
return deobfs
}

View File

@ -1,276 +0,0 @@
package multiplex
import (
"crypto/aes"
"crypto/cipher"
"math/rand"
"reflect"
"testing"
"testing/quick"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305"
)
func TestGenerateObfs(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
run := func(o Obfuscator, t *testing.T) {
obfsBuf := make([]byte, 512)
_testFrame, _ := quick.Value(reflect.TypeOf(Frame{}), rand.New(rand.NewSource(42)))
testFrame := _testFrame.Interface().(Frame)
i, err := o.obfuscate(&testFrame, obfsBuf, 0)
assert.NoError(t, err)
var resultFrame Frame
err = o.deobfuscate(&resultFrame, obfsBuf[:i])
assert.NoError(t, err)
assert.EqualValues(t, testFrame, resultFrame)
}
t.Run("plain", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodPlain, sessionKey)
assert.NoError(t, err)
run(o, t)
})
t.Run("aes-256-gcm", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodAES256GCM, sessionKey)
assert.NoError(t, err)
run(o, t)
})
t.Run("aes-128-gcm", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodAES128GCM, sessionKey)
assert.NoError(t, err)
run(o, t)
})
t.Run("chacha20-poly1305", func(t *testing.T) {
o, err := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
assert.NoError(t, err)
run(o, t)
})
t.Run("unknown encryption method", func(t *testing.T) {
_, err := MakeObfuscator(0xff, sessionKey)
assert.Error(t, err)
})
}
func TestObfuscate(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
const testPayloadLen = 1024
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
f := Frame{
StreamID: 0,
Seq: 0,
Closing: 0,
Payload: testPayload,
}
runTest := func(t *testing.T, o Obfuscator) {
obfsBuf := make([]byte, testPayloadLen*2)
n, err := o.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
resultFrame := Frame{}
err = o.deobfuscate(&resultFrame, obfsBuf[:n])
assert.NoError(t, err)
assert.EqualValues(t, f, resultFrame)
}
t.Run("plain", func(t *testing.T) {
o := Obfuscator{
payloadCipher: nil,
sessionKey: sessionKey,
}
runTest(t, o)
})
t.Run("aes-128-gcm", func(t *testing.T) {
c, err := aes.NewCipher(sessionKey[:16])
assert.NoError(t, err)
payloadCipher, err := cipher.NewGCM(c)
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: sessionKey,
}
runTest(t, o)
})
t.Run("aes-256-gcm", func(t *testing.T) {
c, err := aes.NewCipher(sessionKey[:])
assert.NoError(t, err)
payloadCipher, err := cipher.NewGCM(c)
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: sessionKey,
}
runTest(t, o)
})
t.Run("chacha20-poly1305", func(t *testing.T) {
payloadCipher, err := chacha20poly1305.New(sessionKey[:])
assert.NoError(t, err)
o := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: sessionKey,
}
runTest(t, o)
})
}
func BenchmarkObfs(b *testing.B) {
testPayload := make([]byte, 1024)
rand.Read(testPayload)
testFrame := &Frame{
1,
0,
0,
testPayload,
}
obfsBuf := make([]byte, len(testPayload)*2)
var key [32]byte
rand.Read(key[:])
b.Run("AES256GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0)
}
})
b.Run("AES128GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0)
}
})
b.Run("plain", func(b *testing.B) {
obfuscator := Obfuscator{
payloadCipher: nil,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0)
}
})
b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:])
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
b.SetBytes(int64(len(testFrame.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
obfuscator.obfuscate(testFrame, obfsBuf, 0)
}
})
}
func BenchmarkDeobfs(b *testing.B) {
testPayload := make([]byte, 1024)
rand.Read(testPayload)
testFrame := &Frame{
1,
0,
0,
testPayload,
}
obfsBuf := make([]byte, len(testPayload)*2)
var key [32]byte
rand.Read(key[:])
b.Run("AES256GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.SetBytes(int64(n))
b.ResetTimer()
for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n])
}
})
b.Run("AES128GCM", func(b *testing.B) {
c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c)
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.ResetTimer()
b.SetBytes(int64(n))
for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n])
}
})
b.Run("plain", func(b *testing.B) {
obfuscator := Obfuscator{
payloadCipher: nil,
sessionKey: key,
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.ResetTimer()
b.SetBytes(int64(n))
for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n])
}
})
b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:])
obfuscator := Obfuscator{
payloadCipher: payloadCipher,
sessionKey: key,
}
n, _ := obfuscator.obfuscate(testFrame, obfsBuf, 0)
frame := new(Frame)
b.ResetTimer()
b.SetBytes(int64(n))
for i := 0; i < b.N; i++ {
obfuscator.deobfuscate(frame, obfsBuf[:n])
}
})
}

View File

@ -7,60 +7,48 @@ import (
)
// Valve needs to be universal, across all sessions that belong to a user
type LimitedValve struct {
// traffic directions from the server's perspective are referred
// gabe please don't sue
type Valve struct {
// traffic directions from the server's perspective are refered
// exclusively as rx and tx.
// rx is from client to server, tx is from server to client
// DO NOT use terms up or down as this is used in usermanager
// for bandwidth limiting
rxtb *ratelimit.Bucket
txtb *ratelimit.Bucket
rxtb atomic.Value // *ratelimit.Bucket
txtb atomic.Value // *ratelimit.Bucket
rx *int64
tx *int64
rxCredit *int64
txCredit *int64
}
type UnlimitedValve struct{}
func MakeValve(rxRate, txRate int64) *LimitedValve {
var rx, tx int64
v := &LimitedValve{
rxtb: ratelimit.NewBucketWithRate(float64(rxRate), rxRate),
txtb: ratelimit.NewBucketWithRate(float64(txRate), txRate),
rx: &rx,
tx: &tx,
func MakeValve(rxRate, txRate int64, rxCredit, txCredit *int64) *Valve {
v := &Valve{
rxCredit: rxCredit,
txCredit: txCredit,
}
v.SetRxRate(rxRate)
v.SetTxRate(txRate)
return v
}
var UNLIMITED_VALVE = &UnlimitedValve{}
func (v *Valve) SetRxRate(rate int64) { v.rxtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) }
func (v *LimitedValve) rxWait(n int) { v.rxtb.Wait(int64(n)) }
func (v *LimitedValve) txWait(n int) { v.txtb.Wait(int64(n)) }
func (v *LimitedValve) AddRx(n int64) { atomic.AddInt64(v.rx, n) }
func (v *LimitedValve) AddTx(n int64) { atomic.AddInt64(v.tx, n) }
func (v *LimitedValve) GetRx() int64 { return atomic.LoadInt64(v.rx) }
func (v *LimitedValve) GetTx() int64 { return atomic.LoadInt64(v.tx) }
func (v *LimitedValve) Nullify() (int64, int64) {
rx := atomic.SwapInt64(v.rx, 0)
tx := atomic.SwapInt64(v.tx, 0)
return rx, tx
}
func (v *Valve) SetTxRate(rate int64) { v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) }
func (v *UnlimitedValve) rxWait(n int) {}
func (v *UnlimitedValve) txWait(n int) {}
func (v *UnlimitedValve) AddRx(n int64) {}
func (v *UnlimitedValve) AddTx(n int64) {}
func (v *UnlimitedValve) GetRx() int64 { return 0 }
func (v *UnlimitedValve) GetTx() int64 { return 0 }
func (v *UnlimitedValve) Nullify() (int64, int64) { return 0, 0 }
func (v *Valve) rxWait(n int) { v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
type Valve interface {
rxWait(n int)
txWait(n int)
AddRx(n int64)
AddTx(n int64)
GetRx() int64
GetTx() int64
Nullify() (int64, int64)
}
func (v *Valve) txWait(n int) { v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
func (v *Valve) SetRxCredit(n int64) { atomic.StoreInt64(v.rxCredit, n) }
func (v *Valve) SetTxCredit(n int64) { atomic.StoreInt64(v.txCredit, n) }
func (v *Valve) GetRxCredit() int64 { return atomic.LoadInt64(v.rxCredit) }
func (v *Valve) GetTxCredit() int64 { return atomic.LoadInt64(v.txCredit) }
// n can be negative
func (v *Valve) AddRxCredit(n int64) int64 { return atomic.AddInt64(v.rxCredit, n) }
// n can be negative
func (v *Valve) AddTxCredit(n int64) int64 { return atomic.AddInt64(v.txCredit, n) }

View File

@ -1,24 +0,0 @@
package multiplex
import (
"errors"
"io"
"time"
)
var ErrTimeout = errors.New("deadline exceeded")
type recvBuffer interface {
// Read calls' err must be nil | io.EOF | io.ErrShortBuffer
// Read should NOT return error on a closed streamBuffer with a non-empty buffer.
// Instead, it should behave as if it hasn't been closed. Closure is only relevant
// when the buffer is empty.
io.ReadCloser
Write(*Frame) (toBeClosed bool, err error)
SetReadDeadline(time time.Time)
}
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
// a panic will happen.
const recvBufferSizeLimit = 1<<31 - 1

View File

@ -2,351 +2,164 @@ package multiplex
import (
"errors"
"fmt"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
)
const (
// Copied from smux
acceptBacklog = 1024
defaultInactivityTimeout = 30 * time.Second
defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2
)
var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type SessionConfig struct {
Obfuscator
// Valve is used to limit transmission rates, and record and limit usage
Valve
Unordered bool
// A Singleplexing session always has just one stream
Singleplex bool
// maximum size of an obfuscated frame, including headers and overhead
MsgOnWireSizeLimit int
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself
InactivityTimeout time.Duration
}
// A Session represents a self-contained communication chain between local and remote. It manages its streams,
// controls serialisation and encryption of data sent and received using the supplied Obfuscator, and send and receive
// data through a manged connection pool filled with underlying connections added to it.
type Session struct {
id uint32
SessionConfig
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
obfs Obfser
// Remove TLS header, decrypt and unmarshall multiplexing headers
deobfs Deobfser
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
obfsedRead func(net.Conn, []byte) (int, error)
// atomic
nextStreamID uint32
// atomic
activeStreamCount uint32
streamsM sync.Mutex
streamsM sync.RWMutex
streams map[uint32]*Stream
// For accepting new streams
acceptCh chan *Stream
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool
streamObfsBufPool sync.Pool
// Switchboard manages all connections to remote
sb *switchboard
// Used for LocalAddr() and RemoteAddr() etc.
addrs atomic.Value
// For accepting new streams
acceptCh chan *Stream
closed uint32
terminalMsgSetter sync.Once
terminalMsg string
// the max size passed to Write calls before it splits it into multiple frames
// i.e. the max size a piece of data can fit into a Frame.Payload
maxStreamUnitWrite int
// streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
streamSendBufferSize int
// connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
connReceiveBufferSize int
die chan struct{}
suicide sync.Once
}
func MakeSession(id uint32, config SessionConfig) *Session {
// 1 conn is needed to make a session
func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) *Session {
sesh := &Session{
id: id,
SessionConfig: config,
obfs: obfs,
deobfs: deobfs,
obfsedRead: obfsedRead,
nextStreamID: 1,
streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog),
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
streams: map[uint32]*Stream{},
die: make(chan struct{}),
}
sesh.addrs.Store([]net.Addr{nil, nil})
if config.Valve == nil {
sesh.Valve = UNLIMITED_VALVE
}
if config.MsgOnWireSizeLimit <= 0 {
sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize
}
if config.InactivityTimeout == 0 {
sesh.InactivityTimeout = defaultInactivityTimeout
}
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - maxExtraLen
sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit
sesh.connReceiveBufferSize = 20480 // for backwards compatibility
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
b := make([]byte, sesh.streamSendBufferSize)
return &b
}}
sesh.sb = makeSwitchboard(sesh)
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
sesh.sb = makeSwitchboard(sesh, valve)
return sesh
}
func (sesh *Session) GetSessionKey() [32]byte {
return sesh.sessionKey
}
func (sesh *Session) streamCountIncr() uint32 {
return atomic.AddUint32(&sesh.activeStreamCount, 1)
}
func (sesh *Session) streamCountDecr() uint32 {
return atomic.AddUint32(&sesh.activeStreamCount, ^uint32(0))
}
func (sesh *Session) streamCount() uint32 {
return atomic.LoadUint32(&sesh.activeStreamCount)
}
// AddConnection is used to add an underlying connection to the connection pool
func (sesh *Session) AddConnection(conn net.Conn) {
sesh.sb.addConn(conn)
addrs := []net.Addr{conn.LocalAddr(), conn.RemoteAddr()}
sesh.addrs.Store(addrs)
}
// OpenStream is similar to net.Dial. It opens up a new stream
func (sesh *Session) OpenStream() (*Stream, error) {
if sesh.IsClosed() {
select {
case <-sesh.die:
return nil, ErrBrokenSession
default:
}
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
// Because atomic.AddUint32 returns the value after incrementation
if sesh.Singleplex && id > 1 {
// if there are more than one streams, which shouldn't happen if we are
// singleplexing
return nil, errNoMultiplex
}
stream := makeStream(sesh, id)
stream := makeStream(id, sesh)
sesh.streamsM.Lock()
sesh.streams[id] = stream
sesh.streamsM.Unlock()
sesh.streamCountIncr()
log.Tracef("stream %v of session %v opened", id, sesh.id)
log.Printf("Opening stream %v\n", id)
return stream, nil
}
// Accept is similar to net.Listener's Accept(). It blocks and returns an incoming stream
func (sesh *Session) Accept() (net.Conn, error) {
if sesh.IsClosed() {
func (sesh *Session) AcceptStream() (*Stream, error) {
select {
case <-sesh.die:
return nil, ErrBrokenSession
}
stream := <-sesh.acceptCh
if stream == nil {
return nil, ErrBrokenSession
}
log.Tracef("stream %v of session %v accepted", stream.id, sesh.id)
case stream := <-sesh.acceptCh:
return stream, nil
}
func (sesh *Session) closeStream(s *Stream, active bool) error {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
}
_ = s.recvBuf.Close() // recvBuf.Close should not return error
if active {
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
// Notify remote that this stream is closed
common.CryptoRandRead((*tmpBuf)[:1])
padLen := int((*tmpBuf)[0]) + 1
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
// must be holding s.wirtingM on entry
s.writingFrame.Closing = closingStream
s.writingFrame.Payload = payload
err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
sesh.streamObfsBufPool.Put(tmpBuf)
if err != nil {
return err
}
log.Tracef("stream %v actively closed.", s.id)
} else {
log.Tracef("stream %v passively closed", s.id)
}
// We set it as nil to signify that the stream id had existed before.
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
// if the frame it received was from a new stream or a dying stream whose frame arrived late
func (sesh *Session) delStream(id uint32) {
sesh.streamsM.Lock()
sesh.streams[s.id] = nil
sesh.streamsM.Unlock()
if sesh.streamCountDecr() == 0 {
if sesh.Singleplex {
return sesh.Close()
} else {
log.Debugf("session %v has no active stream left", sesh.id)
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
}
}
return nil
}
// recvDataFromRemote deobfuscate the frame and read the Closing field. If it is a closing frame, it writes the frame
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
// stream and then writes to the stream buffer
func (sesh *Session) recvDataFromRemote(data []byte) error {
frame := sesh.recvFramePool.Get().(*Frame)
defer sesh.recvFramePool.Put(frame)
err := sesh.deobfuscate(frame, data)
if err != nil {
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
}
if frame.Closing == closingSession {
sesh.SetTerminalMsg("Received a closing notification frame")
return sesh.passiveClose()
}
sesh.streamsM.Lock()
if sesh.IsClosed() {
sesh.streamsM.Unlock()
return ErrBrokenSession
}
existingStream, existing := sesh.streams[frame.StreamID]
if existing {
sesh.streamsM.Unlock()
if existingStream == nil {
// this is when the stream existed before but has since been closed. We do nothing
return nil
}
return existingStream.recvFrame(frame)
} else {
newStream := makeStream(sesh, frame.StreamID)
sesh.streams[frame.StreamID] = newStream
sesh.acceptCh <- newStream
sesh.streamsM.Unlock()
// new stream
sesh.streamCountIncr()
return newStream.recvFrame(frame)
}
}
func (sesh *Session) SetTerminalMsg(msg string) {
log.Debug("terminal message set to " + msg)
sesh.terminalMsgSetter.Do(func() {
sesh.terminalMsg = msg
})
}
func (sesh *Session) TerminalMsg() string {
return sesh.terminalMsg
}
func (sesh *Session) closeSession() error {
if !atomic.CompareAndSwapUint32(&sesh.closed, 0, 1) {
log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing
}
sesh.streamsM.Lock()
close(sesh.acceptCh)
for id, stream := range sesh.streams {
if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) {
_ = stream.recvBuf.Close() // will not block
delete(sesh.streams, id)
sesh.streamCountDecr()
}
}
sesh.streamsM.Unlock()
return nil
}
func (sesh *Session) passiveClose() error {
log.Debugf("attempting to passively close session %v", sesh.id)
err := sesh.closeSession()
if err != nil {
return err
func (sesh *Session) isStream(id uint32) bool {
sesh.streamsM.RLock()
_, ok := sesh.streams[id]
sesh.streamsM.RUnlock()
return ok
}
sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id)
// If the stream has been closed and the triggering frame is a closing frame,
// we return nil
func (sesh *Session) getOrAddStream(id uint32, closingFrame bool) *Stream {
// it would have been neater to use defer Unlock(), however it gives
// non-negligable overhead and this function is performance critical
sesh.streamsM.Lock()
stream := sesh.streams[id]
if stream != nil {
sesh.streamsM.Unlock()
return stream
} else {
if closingFrame {
sesh.streamsM.Unlock()
return nil
} else {
stream = makeStream(id, sesh)
sesh.streams[id] = stream
sesh.acceptCh <- stream
log.Printf("Adding stream %v\n", id)
sesh.streamsM.Unlock()
return stream
}
}
}
func (sesh *Session) getStream(id uint32) *Stream {
sesh.streamsM.RLock()
ret := sesh.streams[id]
sesh.streamsM.RUnlock()
return ret
}
// addStream is used when the remote opened a new stream and we got notified
func (sesh *Session) addStream(id uint32) *Stream {
stream := makeStream(id, sesh)
sesh.streamsM.Lock()
sesh.streams[id] = stream
sesh.streamsM.Unlock()
sesh.acceptCh <- stream
log.Printf("Adding stream %v\n", id)
return stream
}
func (sesh *Session) Close() error {
log.Debugf("attempting to actively close session %v", sesh.id)
err := sesh.closeSession()
if err != nil {
return err
// Because closing a closed channel causes panic
sesh.suicide.Do(func() { close(sesh.die) })
sesh.streamsM.Lock()
for id, stream := range sesh.streams {
// If we call stream.Close() here, streamsM will result in a deadlock
// because stream.Close calls sesh.delStream, which locks the mutex.
// so we need to implement a method of stream that closes the stream without calling
// sesh.delStream
// This can also be seen in smux
go stream.closeNoDelMap()
delete(sesh.streams, id)
}
// we send a notice frame telling remote to close the session
sesh.streamsM.Unlock()
buf := sesh.streamObfsBufPool.Get().(*[]byte)
common.CryptoRandRead((*buf)[:1])
padLen := int((*buf)[0]) + 1
payload := (*buf)[frameHeaderLength : padLen+frameHeaderLength]
common.CryptoRandRead(payload)
f := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: closingSession,
Payload: payload,
}
i, err := sesh.obfuscate(f, *buf, frameHeaderLength)
if err != nil {
return err
}
_, err = sesh.sb.send((*buf)[:i], new(net.Conn))
if err != nil {
return err
}
sesh.sb.closeAll()
log.Debugf("session %v closed gracefully", sesh.id)
sesh.sb.shutdown()
return nil
}
func (sesh *Session) IsClosed() bool {
return atomic.LoadUint32(&sesh.closed) == 1
}
func (sesh *Session) checkTimeout() {
if sesh.streamCount() == 0 && !sesh.IsClosed() {
sesh.SetTerminalMsg("timeout")
sesh.Close()
}
}
func (sesh *Session) Addr() net.Addr { return sesh.addrs.Load().([]net.Addr)[0] }

View File

@ -1,24 +0,0 @@
//go:build gofuzz
// +build gofuzz
package multiplex
func setupSesh_fuzz(unordered bool) *Session {
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, [32]byte{})
seshConfig := SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: unordered,
}
return MakeSession(0, seshConfig)
}
func Fuzz(data []byte) int {
sesh := setupSesh_fuzz(false)
err := sesh.recvDataFromRemote(data)
if err == nil {
return 1
}
return 0
}

View File

@ -1,640 +0,0 @@
package multiplex
import (
"bytes"
"io"
"io/ioutil"
"math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
)
var seshConfigs = map[string]SessionConfig{
"ordered": {},
"unordered": {Unordered: true},
}
var encryptionMethods = map[string]byte{
"plain": EncryptionMethodPlain,
"aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
}
const testPayloadLen = 1024
const obfsBufLen = testPayloadLen * 2
func TestRecvDataFromRemote(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
var err error
seshConfig.Obfuscator, err = MakeObfuscator(EncryptionMethodPlain, sessionKey)
if err != nil {
t.Fatalf("failed to make obfuscator: %v", err)
}
t.Run("initial frame", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
make([]byte, testPayloadLen),
}
rand.Read(f.Payload)
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
_, err = stream.Read(resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
})
t.Run("two frames in order", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
make([]byte, testPayloadLen),
}
rand.Read(f.Payload)
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
f.Seq += 1
rand.Read(f.Payload)
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
})
t.Run("two frames in order", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
make([]byte, testPayloadLen),
}
rand.Read(f.Payload)
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
f.Seq += 1
rand.Read(f.Payload)
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, f.Payload, resultPayload)
})
if seshType == "ordered" {
t.Run("frames out of order", func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
obfsBuf := make([]byte, obfsBufLen)
f := Frame{
1,
0,
0,
nil,
}
// First frame
seq0 := make([]byte, testPayloadLen)
rand.Read(seq0)
f.Seq = 0
f.Payload = seq0
n, err := sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
// Third frame
seq2 := make([]byte, testPayloadLen)
rand.Read(seq2)
f.Seq = 2
f.Payload = seq2
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
// Second frame
seq1 := make([]byte, testPayloadLen)
rand.Read(seq1)
f.Seq = 1
f.Payload = seq1
n, err = sesh.obfuscate(&f, obfsBuf, 0)
assert.NoError(t, err)
err = sesh.recvDataFromRemote(obfsBuf[:n])
assert.NoError(t, err)
// Expect things to receive in order
stream, err := sesh.Accept()
assert.NoError(t, err)
resultPayload := make([]byte, testPayloadLen)
// First
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, seq0, resultPayload)
// Second
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, seq1, resultPayload)
// Third
_, err = io.ReadFull(stream, resultPayload)
assert.NoError(t, err)
assert.EqualValues(t, seq2, resultPayload)
})
}
})
}
}
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte
rand.Read(sessionKey[:])
seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey)
sesh := MakeSession(0, seshConfig)
f1 := &Frame{
1,
0,
closingNothing,
testPayload,
}
// create stream 1
n, _ := sesh.obfuscate(f1, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err)
}
sesh.streamsM.Lock()
_, ok := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if !ok {
t.Fatal("failed to fetch stream 1 after receiving it")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1")
}
// create stream 2
f2 := &Frame{
2,
0,
closingNothing,
testPayload,
}
n, _ = sesh.obfuscate(f2, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 2: %v", err)
}
sesh.streamsM.Lock()
s2M, ok := sesh.streams[f2.StreamID]
sesh.streamsM.Unlock()
if s2M == nil || !ok {
t.Fatal("failed to fetch stream 2 after receiving it")
}
if sesh.streamCount() != 2 {
t.Error("stream count isn't 2")
}
// close stream 1
f1CloseStream := &Frame{
1,
1,
closingStream,
testPayload,
}
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
}
sesh.streamsM.Lock()
s1M, _ := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Fatal("stream 1 still exist after receiving stream close")
}
s1, _ := sesh.Accept()
if !s1.(*Stream).isClosed() {
t.Fatal("stream 1 not marked as closed")
}
payloadBuf := make([]byte, testPayloadLen)
_, err = s1.Read(payloadBuf)
if err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Fatalf("failed to read from stream 1 after closing: %v", err)
}
s2, _ := sesh.Accept()
if s2.(*Stream).isClosed() {
t.Fatal("stream 2 shouldn't be closed")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 closed")
}
// close stream 1 again
n, _ = sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
}
sesh.streamsM.Lock()
s1M, _ = sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Error("stream 1 exists after receiving stream close for the second time")
}
streamCount := sesh.streamCount()
if streamCount != 1 {
t.Errorf("stream count is %v after stream 1 closed twice, expected 1", streamCount)
}
// close session
fCloseSession := &Frame{
StreamID: 0xffffffff,
Seq: 0,
Closing: closingSession,
Payload: testPayload,
}
n, _ = sesh.obfuscate(fCloseSession, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving session closing frame: %v", err)
}
if !sesh.IsClosed() {
t.Error("session not closed after receiving signal")
}
if !s2.(*Stream).isClosed() {
t.Error("stream 2 isn't closed after session closed")
}
if _, err := s2.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Error("failed to read from stream 2 after session closed")
}
if _, err := s2.Write(testPayload); err == nil {
t.Error("can still write to stream 2 after session closed")
}
if sesh.streamCount() != 0 {
t.Error("stream count isn't 0 after session closed")
}
}
func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
// Tests for when the closing frame of a stream is received first before any data frame
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte
rand.Read(sessionKey[:])
seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator, _ = MakeObfuscator(EncryptionMethodPlain, sessionKey)
sesh := MakeSession(0, seshConfig)
// receive stream 1 closing first
f1CloseStream := &Frame{
1,
1,
closingStream,
testPayload,
}
n, _ := sesh.obfuscate(f1CloseStream, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
}
sesh.streamsM.Lock()
_, ok := sesh.streams[f1CloseStream.StreamID]
sesh.streamsM.Unlock()
if !ok {
t.Fatal("stream 1 doesn't exist")
}
if sesh.streamCount() != 1 {
t.Error("stream count isn't 1 after stream 1 received")
}
// receive data frame of stream 1 after receiving the closing frame
f1 := &Frame{
1,
0,
closingNothing,
testPayload,
}
n, _ = sesh.obfuscate(f1, obfsBuf, 0)
err = sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err)
}
s1, err := sesh.Accept()
if err != nil {
t.Fatal("failed to accept stream 1 after receiving it")
}
payloadBuf := make([]byte, testPayloadLen)
if _, err := s1.Read(payloadBuf); err != nil || !bytes.Equal(payloadBuf, testPayload) {
t.Error("failed to read from steam 1")
}
if !s1.(*Stream).isClosed() {
t.Error("s1 isn't closed")
}
if sesh.streamCount() != 0 {
t.Error("stream count isn't 0 after stream 1 closed")
}
}
func TestParallelStreams(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
numStreams := acceptBacklog
seqs := make([]*uint64, numStreams)
for i := range seqs {
seqs[i] = new(uint64)
}
randFrame := func() *Frame {
id := rand.Intn(numStreams)
return &Frame{
uint32(id),
atomic.AddUint64(seqs[id], 1) - 1,
uint8(rand.Intn(2)),
[]byte{1, 2, 3, 4},
}
}
const numOfTests = 5000
tests := make([]struct {
name string
frame *Frame
}, numOfTests)
for i := range tests {
tests[i].name = strconv.Itoa(i)
tests[i].frame = randFrame()
}
var wg sync.WaitGroup
for _, tc := range tests {
wg.Add(1)
go func(frame *Frame) {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.obfuscate(frame, obfsBuf, 0)
obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(obfsBuf)
if err != nil {
t.Error(err)
}
wg.Done()
}(tc.frame)
}
wg.Wait()
sc := int(sesh.streamCount())
var count int
sesh.streamsM.Lock()
for _, s := range sesh.streams {
if s != nil {
count++
}
}
sesh.streamsM.Unlock()
if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
}
})
}
}
func TestStream_SetReadDeadline(t *testing.T) {
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
t.Run("read after deadline set", func(t *testing.T) {
stream, _ := sesh.OpenStream()
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
_, err := stream.Read(make([]byte, 1))
if err != ErrTimeout {
t.Errorf("expecting error %v, got %v", ErrTimeout, err)
}
})
t.Run("unblock when deadline passed", func(t *testing.T) {
stream, _ := sesh.OpenStream()
done := make(chan struct{})
go func() {
_, _ = stream.Read(make([]byte, 1))
done <- struct{}{}
}()
_ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
select {
case <-done:
return
case <-time.After(500 * time.Millisecond):
t.Error("Read did not unblock after deadline has passed")
}
})
})
}
}
func TestSession_timeoutAfter(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
seshConfig.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfig)
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
})
}
}
func BenchmarkRecvDataFromRemote(b *testing.B) {
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
f := Frame{
1,
0,
0,
testPayload,
}
var sessionKey [32]byte
rand.Read(sessionKey[:])
const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic
for name, ep := range encryptionMethods {
ep := ep
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) {
f := f
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
sesh := MakeSession(0, seshConfig)
go func() {
stream, _ := sesh.Accept()
io.Copy(ioutil.Discard, stream)
}()
binaryFrames := [maxIter][]byte{}
for i := 0; i < maxIter; i++ {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.obfuscate(&f, obfsBuf, 0)
binaryFrames[i] = obfsBuf[:n]
f.Seq++
}
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(binaryFrames[i])
}
})
}
})
}
}
func BenchmarkMultiStreamWrite(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
testPayload := make([]byte, testPayloadLen)
for name, ep := range encryptionMethods {
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) {
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
b.ResetTimer()
b.SetBytes(testPayloadLen)
b.RunParallel(func(pb *testing.PB) {
stream, _ := sesh.OpenStream()
for pb.Next() {
stream.Write(testPayload)
}
})
})
}
})
}
}
func BenchmarkLatency(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
for name, ep := range encryptionMethods {
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) {
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
clientSesh := MakeSession(0, seshConfig)
serverSesh := MakeSession(0, seshConfig)
c, s := net.Pipe()
clientSesh.AddConnection(c)
serverSesh.AddConnection(s)
buf := make([]byte, 64)
clientStream, _ := clientSesh.OpenStream()
clientStream.Write(buf)
serverStream, _ := serverSesh.Accept()
io.ReadFull(serverStream, buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
clientStream.Write(buf)
io.ReadFull(serverStream, buf)
}
})
}
})
}
}

View File

@ -2,210 +2,156 @@ package multiplex
import (
"errors"
"io"
"net"
"time"
"log"
"math"
prand "math/rand"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
)
var ErrBrokenStream = errors.New("broken stream")
var errBrokenStream = errors.New("broken stream")
// Stream implements net.Conn. It represents an optionally-ordered, full-duplex, self-contained connection.
// If the session it belongs to runs in ordered mode, it provides ordering guarantee regardless of the underlying
// connection used.
// If the underlying connections the session uses are reliable, Stream is reliable. If they are not, Stream does not
// guarantee reliability.
type Stream struct {
id uint32
session *Session
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
// been read by the consumer through Read or WriteTo.
recvBuf recvBuffer
// Explanations of the following 4 fields can be found in frameSorter.go
nextRecvSeq uint32
rev int
sh sorterHeap
wrapMode bool
writingM sync.Mutex
writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom
// New frames are received through newFrameCh by frameSorter
newFrameCh chan *Frame
// sortedBufCh are order-sorted data ready to be read raw
sortedBufCh chan []byte
// atomic
closed uint32
nextSendSeq uint32
// When we want order guarantee (i.e. session.Unordered is false),
// we assign each stream a fixed underlying connection.
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
// so it won't have to use its priority queue to sort it.
// This is not used in unordered connection mode
assignedConn net.Conn
writingM sync.RWMutex
readFromTimeout time.Duration
// close(die) is used to notify different goroutines that this stream is closing
die chan struct{}
heliumMask sync.Once // my personal fav
}
func makeStream(sesh *Session, id uint32) *Stream {
func makeStream(id uint32, sesh *Session) *Stream {
stream := &Stream{
id: id,
session: sesh,
writingFrame: Frame{
StreamID: id,
Seq: 0,
Closing: closingNothing,
},
die: make(chan struct{}),
sh: []*frameNode{},
newFrameCh: make(chan *Frame, 1024),
sortedBufCh: make(chan []byte, 1024),
}
if sesh.Unordered {
stream.recvBuf = NewDatagramBufferedPipe()
} else {
stream.recvBuf = NewStreamBuffer()
}
go stream.recvNewFrame()
return stream
}
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
// receive a readily deobfuscated Frame so its payload can later be Read
func (s *Stream) recvFrame(frame *Frame) error {
toBeClosed, err := s.recvBuf.Write(frame)
if toBeClosed {
err = s.passiveClose()
if errors.Is(err, errRepeatStreamClosing) {
log.Debug(err)
return nil
}
return err
}
return err
}
// Read implements io.Read
func (s *Stream) Read(buf []byte) (n int, err error) {
//log.Tracef("attempting to read from stream %v", s.id)
func (stream *Stream) Read(buf []byte) (n int, err error) {
if len(buf) == 0 {
select {
case <-stream.die:
return 0, errBrokenStream
default:
return 0, nil
}
n, err = s.recvBuf.Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF {
return n, ErrBrokenStream
}
return
select {
case <-stream.die:
return 0, errBrokenStream
case data := <-stream.sortedBufCh:
if len(data) == 0 {
stream.passiveClose()
return 0, errBrokenStream
}
if len(buf) < len(data) {
log.Println(len(data))
return 0, errors.New("buf too small")
}
copy(buf, data)
return len(data), nil
}
func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
cipherTextLen, err := s.session.obfuscate(&s.writingFrame, buf, payloadOffsetInBuf)
s.writingFrame.Seq++
}
func (stream *Stream) Write(in []byte) (n int, err error) {
// RWMutex used here isn't really for RW.
// we use it to exploit the fact that RLock doesn't create contention.
// The use of RWMutex is so that the stream will not actively close
// in the middle of the execution of Write. This may cause the closing frame
// to be sent before the data frame and cause loss of packet.
stream.writingM.RLock()
select {
case <-stream.die:
stream.writingM.RUnlock()
return 0, errBrokenStream
default:
}
f := &Frame{
StreamID: stream.id,
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
Closing: 0,
Payload: in,
}
tlsRecord, err := stream.session.obfs(f)
if err != nil {
return err
return 0, err
}
n, err = stream.session.sb.send(tlsRecord)
stream.writingM.RUnlock()
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
}
return err
}
return nil
}
// Write implements io.Write
func (s *Stream) Write(in []byte) (n int, err error) {
s.writingM.Lock()
defer s.writingM.Unlock()
if s.isClosed() {
return 0, ErrBrokenStream
}
for n < len(in) {
var framePayload []byte
if len(in)-n <= s.session.maxStreamUnitWrite {
// if we can fit remaining data of in into one frame
framePayload = in[n:]
} else {
// if we have to split
if s.session.Unordered {
// but we are not allowed to
err = io.ErrShortBuffer
return
}
framePayload = in[n : s.session.maxStreamUnitWrite+n]
}
s.writingFrame.Payload = framePayload
buf := s.session.streamObfsBufPool.Get().(*[]byte)
err = s.obfuscateAndSend(*buf, 0)
s.session.streamObfsBufPool.Put(buf)
if err != nil {
return
}
n += len(framePayload)
}
return
}
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
// for readFromTimeout amount of time
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
for {
if s.readFromTimeout != 0 {
if rder, ok := r.(net.Conn); !ok {
log.Warn("ReadFrom timeout is set but reader doesn't implement SetReadDeadline")
} else {
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
}
}
buf := s.session.streamObfsBufPool.Get().(*[]byte)
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
if er != nil {
return n, er
}
// the above read may have been unblocked by another goroutine calling stream.Close(), so we need
// to check that here
if s.isClosed() {
return n, ErrBrokenStream
}
s.writingM.Lock()
s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
err = s.obfuscateAndSend(*buf, frameHeaderLength)
s.writingM.Unlock()
s.session.streamObfsBufPool.Put(buf)
if err != nil {
return
}
n += int64(read)
}
}
func (s *Stream) passiveClose() error {
return s.session.closeStream(s, false)
// only close locally. Used when the stream close is notified by the remote
func (stream *Stream) passiveClose() {
stream.heliumMask.Do(func() { close(stream.die) })
stream.session.delStream(stream.id)
log.Printf("%v passive closing\n", stream.id)
}
// active close. Close locally and tell the remote that this stream is being closed
func (s *Stream) Close() error {
s.writingM.Lock()
defer s.writingM.Unlock()
func (stream *Stream) Close() error {
return s.session.closeStream(s, true)
stream.writingM.Lock()
select {
case <-stream.die:
stream.writingM.Unlock()
return errors.New("Already Closed")
default:
}
stream.heliumMask.Do(func() { close(stream.die) })
// Notify remote that this stream is closed
prand.Seed(int64(stream.id))
padLen := int(math.Floor(prand.Float64()*200 + 300))
pad := make([]byte, padLen)
prand.Read(pad)
f := &Frame{
StreamID: stream.id,
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
Closing: 1,
Payload: pad,
}
tlsRecord, _ := stream.session.obfs(f)
stream.session.sb.send(tlsRecord)
stream.session.delStream(stream.id)
log.Printf("%v actively closed\n", stream.id)
stream.writingM.Unlock()
return nil
}
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
var errNotImplemented = errors.New("Not implemented")
// the following functions are purely for implementing net.Conn interface.
// they are not used
// TODO: implement the following
func (s *Stream) SetDeadline(t time.Time) error { return errNotImplemented }
func (s *Stream) SetWriteDeadline(t time.Time) error { return errNotImplemented }
// Same as passiveClose() but no call to session.delStream.
// This is called in session.Close() to avoid mutex deadlock
// We don't notify the remote because session.Close() is always
// called when the session is passively closed
func (stream *Stream) closeNoDelMap() {
stream.heliumMask.Do(func() { close(stream.die) })
}

View File

@ -1,111 +0,0 @@
package multiplex
// The data is multiplexed through several TCP connections, therefore the
// order of arrival is not guaranteed. A stream's first packet may be sent through
// connection0 and its second packet may be sent through connection1. Although both
// packets are transmitted reliably (as TCP is reliable), packet1 may arrive to the
// remote side before packet0. Cloak have to therefore sequence the packets so that they
// arrive in order as they were sent by the proxy software
//
// Cloak packets will have a 64-bit sequence number on them, so we know in which order
// they should be sent to the proxy software. The code in this file provides buffering and sorting.
import (
"container/heap"
"fmt"
"sync"
"time"
)
type sorterHeap []*Frame
func (sh sorterHeap) Less(i, j int) bool {
return sh[i].Seq < sh[j].Seq
}
func (sh sorterHeap) Len() int {
return len(sh)
}
func (sh sorterHeap) Swap(i, j int) {
sh[i], sh[j] = sh[j], sh[i]
}
func (sh *sorterHeap) Push(x interface{}) {
*sh = append(*sh, x.(*Frame))
}
func (sh *sorterHeap) Pop() interface{} {
old := *sh
n := len(old)
x := old[n-1]
*sh = old[0 : n-1]
return x
}
type streamBuffer struct {
recvM sync.Mutex
nextRecvSeq uint64
sh sorterHeap
buf *streamBufferedPipe
}
// streamBuffer is a wrapper around streamBufferedPipe.
// Its main function is to sort frames in order, and wait for frames to arrive
// if they have arrived out-of-order. Then it writes the payload of frames into
// a streamBufferedPipe.
func NewStreamBuffer() *streamBuffer {
sb := &streamBuffer{
sh: []*Frame{},
buf: NewStreamBufferedPipe(),
}
return sb
}
func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) {
sb.recvM.Lock()
defer sb.recvM.Unlock()
// when there'fs no ooo packages in heap and we receive the next package in order
if len(sb.sh) == 0 && f.Seq == sb.nextRecvSeq {
if f.Closing != closingNothing {
return true, nil
} else {
sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1
}
return false, nil
}
if f.Seq < sb.nextRecvSeq {
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
}
saved := *f
saved.Payload = make([]byte, len(f.Payload))
copy(saved.Payload, f.Payload)
heap.Push(&sb.sh, &saved)
// Keep popping from the heap until empty or to the point that the wanted seq was not received
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
f = heap.Pop(&sb.sh).(*Frame)
if f.Closing != closingNothing {
return true, nil
} else {
sb.buf.Write(f.Payload)
sb.nextRecvSeq += 1
}
}
return false, nil
}
func (sb *streamBuffer) Read(buf []byte) (int, error) {
return sb.buf.Read(buf)
}
func (sb *streamBuffer) Close() error {
sb.recvM.Lock()
defer sb.recvM.Unlock()
return sb.buf.Close()
}
func (sb *streamBuffer) SetReadDeadline(t time.Time) { sb.buf.SetReadDeadline(t) }

View File

@ -1,91 +0,0 @@
package multiplex
import (
"encoding/binary"
"io"
//"log"
"sort"
"testing"
)
func TestRecvNewFrame(t *testing.T) {
inOrder := []uint64{5, 6, 7, 8, 9, 10, 11}
outOfOrder0 := []uint64{5, 7, 8, 6, 11, 10, 9}
outOfOrder1 := []uint64{1, 96, 47, 2, 29, 18, 60, 8, 74, 22, 82, 58, 44, 51, 57, 71, 90, 94, 68, 83, 61, 91, 39, 97, 85, 63, 46, 73, 54, 84, 76, 98, 93, 79, 75, 50, 67, 37, 92, 99, 42, 77, 17, 16, 38, 3, 100, 24, 31, 7, 36, 40, 86, 64, 34, 45, 12, 5, 9, 27, 21, 26, 35, 6, 65, 69, 53, 4, 48, 28, 30, 56, 32, 11, 80, 66, 25, 41, 78, 13, 88, 62, 15, 70, 49, 43, 72, 23, 10, 55, 52, 95, 14, 59, 87, 33, 19, 20, 81, 89}
outOfOrder2 := []uint64{1<<32 - 5, 1<<32 + 3, 1 << 32, 1<<32 - 3, 1<<32 - 4, 1<<32 + 2, 1<<32 - 2, 1<<32 - 1, 1<<32 + 1}
test := func(set []uint64, ct *testing.T) {
sb := NewStreamBuffer()
sb.nextRecvSeq = set[0]
for _, n := range set {
bu64 := make([]byte, 8)
binary.BigEndian.PutUint64(bu64, n)
sb.Write(&Frame{
Seq: n,
Payload: bu64,
})
}
var sortedResult []uint64
for x := 0; x < len(set); x++ {
oct := make([]byte, 8)
n, err := sb.Read(oct)
if n != 8 || err != nil {
ct.Error("failed to read from sorted Buf", n, err)
return
}
//log.Print(p)
sortedResult = append(sortedResult, binary.BigEndian.Uint64(oct))
}
targetSorted := make([]uint64, len(set))
copy(targetSorted, set)
sort.Slice(targetSorted, func(i, j int) bool { return targetSorted[i] < targetSorted[j] })
for i := range targetSorted {
if sortedResult[i] != targetSorted[i] {
goto fail
}
}
sb.Close()
return
fail:
ct.Error(
"expecting", targetSorted,
"got", sortedResult,
)
}
t.Run("in order", func(t *testing.T) {
test(inOrder, t)
})
t.Run("out of order0", func(t *testing.T) {
test(outOfOrder0, t)
})
t.Run("out of order1", func(t *testing.T) {
test(outOfOrder1, t)
})
t.Run("out of order wrap", func(t *testing.T) {
test(outOfOrder2, t)
})
}
func TestStreamBuffer_RecvThenClose(t *testing.T) {
const testDataLen = 128
sb := NewStreamBuffer()
testData := make([]byte, testDataLen)
testFrame := Frame{
StreamID: 0,
Seq: 0,
Closing: 0,
Payload: testData,
}
sb.Write(&testFrame)
sb.Close()
readBuf := make([]byte, testDataLen)
_, err := io.ReadFull(sb, readBuf)
if err != nil {
t.Error(err)
}
}

View File

@ -1,102 +0,0 @@
// This is base on https://github.com/golang/go/blob/0436b162397018c45068b47ca1b5924a3eafdee0/src/net/net_fake.go#L173
package multiplex
import (
"bytes"
"io"
"sync"
"time"
)
// The point of a streamBufferedPipe is that Read() will block until data is available
type streamBufferedPipe struct {
buf *bytes.Buffer
closed bool
rwCond *sync.Cond
rDeadline time.Time
wtTimeout time.Duration
timeoutTimer *time.Timer
}
func NewStreamBufferedPipe() *streamBufferedPipe {
p := &streamBufferedPipe{
rwCond: sync.NewCond(&sync.Mutex{}),
buf: new(bytes.Buffer),
}
return p
}
func (p *streamBufferedPipe) Read(target []byte) (int, error) {
p.rwCond.L.Lock()
defer p.rwCond.L.Unlock()
for {
if p.closed && p.buf.Len() == 0 {
return 0, io.EOF
}
hasRDeadline := !p.rDeadline.IsZero()
if hasRDeadline {
if time.Until(p.rDeadline) <= 0 {
return 0, ErrTimeout
}
}
if p.buf.Len() > 0 {
break
}
if hasRDeadline {
p.broadcastAfter(time.Until(p.rDeadline))
}
p.rwCond.Wait()
}
n, err := p.buf.Read(target)
// err will always be nil because we have already verified that buf.Len() != 0
p.rwCond.Broadcast()
return n, err
}
func (p *streamBufferedPipe) Write(input []byte) (int, error) {
p.rwCond.L.Lock()
defer p.rwCond.L.Unlock()
for {
if p.closed {
return 0, io.ErrClosedPipe
}
if p.buf.Len() <= recvBufferSizeLimit {
// if p.buf gets too large, write() will panic. We don't want this to happen
break
}
p.rwCond.Wait()
}
n, err := p.buf.Write(input)
// err will always be nil
p.rwCond.Broadcast()
return n, err
}
func (p *streamBufferedPipe) Close() error {
p.rwCond.L.Lock()
defer p.rwCond.L.Unlock()
p.closed = true
p.rwCond.Broadcast()
return nil
}
func (p *streamBufferedPipe) SetReadDeadline(t time.Time) {
p.rwCond.L.Lock()
defer p.rwCond.L.Unlock()
p.rDeadline = t
p.rwCond.Broadcast()
}
func (p *streamBufferedPipe) broadcastAfter(d time.Duration) {
if p.timeoutTimer != nil {
p.timeoutTimer.Stop()
}
p.timeoutTimer = time.AfterFunc(d, p.rwCond.Broadcast)
}

View File

@ -1,93 +0,0 @@
package multiplex
import (
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const readBlockTime = 500 * time.Millisecond
func TestPipeRW(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
n, err := pipe.Write(b)
assert.NoError(t, err, "simple write")
assert.Equal(t, len(b), n, "number of bytes written")
b2 := make([]byte, len(b))
n, err = pipe.Read(b2)
assert.NoError(t, err, "simple read")
assert.Equal(t, len(b), n, "number of bytes read")
assert.Equal(t, b, b2)
}
func TestReadBlock(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
go func() {
time.Sleep(readBlockTime)
pipe.Write(b)
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
assert.NoError(t, err, "blocked read")
assert.Equal(t, len(b), n, "number of bytes read after block")
assert.Equal(t, b, b2)
}
func TestPartialRead(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b1 := make([]byte, 1)
n, err := pipe.Read(b1)
assert.NoError(t, err, "partial read of 1")
assert.Equal(t, len(b1), n, "number of bytes in partial read of 1")
assert.Equal(t, b[0], b1[0])
b2 := make([]byte, 2)
n, err = pipe.Read(b2)
assert.NoError(t, err, "partial read of 2")
assert.Equal(t, len(b2), n, "number of bytes in partial read of 2")
assert.Equal(t, b[1:], b2)
}
func TestReadAfterClose(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
pipe.Write(b)
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
assert.NoError(t, err, "simple read")
assert.Equal(t, len(b), n, "number of bytes read")
assert.Equal(t, b, b2)
}
func BenchmarkBufferedPipe_RW(b *testing.B) {
const PAYLOAD_LEN = 1000
testData := make([]byte, PAYLOAD_LEN)
rand.Read(testData)
pipe := NewStreamBufferedPipe()
smallBuf := make([]byte, PAYLOAD_LEN-10)
go func() {
for {
pipe.Read(smallBuf)
}
}()
b.SetBytes(int64(len(testData)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
pipe.Write(testData)
}
}

View File

@ -1,388 +0,0 @@
package multiplex
import (
"bytes"
"io"
"math/rand"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
"github.com/cbeuw/connutil"
)
const payloadLen = 1000
var emptyKey [32]byte
func setupSesh(unordered bool, key [32]byte, encryptionMethod byte) *Session {
obfuscator, _ := MakeObfuscator(encryptionMethod, key)
seshConfig := SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: unordered,
}
return MakeSession(0, seshConfig)
}
func BenchmarkStream_Write_Ordered(b *testing.B) {
hole := connutil.Discard()
var sessionKey [32]byte
rand.Read(sessionKey[:])
const testDataLen = 65536
testData := make([]byte, testDataLen)
rand.Read(testData)
eMethods := map[string]byte{
"plain": EncryptionMethodPlain,
"chacha20-poly1305": EncryptionMethodChaha20Poly1305,
"aes-256-gcm": EncryptionMethodAES256GCM,
"aes-128-gcm": EncryptionMethodAES128GCM,
}
for name, method := range eMethods {
b.Run(name, func(b *testing.B) {
sesh := setupSesh(false, sessionKey, method)
sesh.AddConnection(hole)
stream, _ := sesh.OpenStream()
b.SetBytes(testDataLen)
b.ResetTimer()
for i := 0; i < b.N; i++ {
stream.Write(testData)
}
})
}
}
func TestStream_Write(t *testing.T) {
hole := connutil.Discard()
var sessionKey [32]byte
rand.Read(sessionKey[:])
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
sesh.AddConnection(hole)
testData := make([]byte, payloadLen)
rand.Read(testData)
stream, _ := sesh.OpenStream()
_, err := stream.Write(testData)
if err != nil {
t.Error(
"For", "stream write",
"got", err,
)
}
}
func TestStream_WriteSync(t *testing.T) {
// Close calls made after write MUST have a higher seq
var sessionKey [32]byte
rand.Read(sessionKey[:])
clientSesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
serverSesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
w, r := connutil.AsyncPipe()
clientSesh.AddConnection(common.NewTLSConn(w))
serverSesh.AddConnection(common.NewTLSConn(r))
testData := make([]byte, payloadLen)
rand.Read(testData)
t.Run("test single", func(t *testing.T) {
go func() {
stream, _ := clientSesh.OpenStream()
stream.Write(testData)
stream.Close()
}()
recvBuf := make([]byte, payloadLen)
serverStream, _ := serverSesh.Accept()
_, err := io.ReadFull(serverStream, recvBuf)
if err != nil {
t.Error(err)
}
})
t.Run("test multiple", func(t *testing.T) {
const numStreams = 100
for i := 0; i < numStreams; i++ {
go func() {
stream, _ := clientSesh.OpenStream()
stream.Write(testData)
stream.Close()
}()
}
for i := 0; i < numStreams; i++ {
recvBuf := make([]byte, payloadLen)
serverStream, _ := serverSesh.Accept()
_, err := io.ReadFull(serverStream, recvBuf)
if err != nil {
t.Error(err)
}
}
})
}
func TestStream_Close(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
testPayload := []byte{42, 42, 42}
dataFrame := &Frame{
1,
0,
0,
testPayload,
}
t.Run("active closing", func(t *testing.T) {
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
rawConn, rawWritingEnd := connutil.AsyncPipe()
sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd)
obfsBuf := make([]byte, 512)
i, _ := sesh.obfuscate(dataFrame, obfsBuf, 0)
_, err := writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Error("failed to write from remote end")
}
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
return
}
time.Sleep(500 * time.Millisecond)
err = stream.Close()
if err != nil {
t.Error("failed to actively close stream", err)
return
}
sesh.streamsM.Lock()
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
sesh.streamsM.Unlock()
t.Error("stream still exists")
return
}
sesh.streamsM.Unlock()
readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf)
if err != nil {
t.Errorf("cannot read resiual data: %v", err)
}
if !bytes.Equal(readBuf, testPayload) {
t.Errorf("read wrong data")
}
})
t.Run("passive closing", func(t *testing.T) {
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
rawConn, rawWritingEnd := connutil.AsyncPipe()
sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd)
obfsBuf := make([]byte, 512)
i, err := sesh.obfuscate(dataFrame, obfsBuf, 0)
if err != nil {
t.Errorf("failed to obfuscate frame %v", err)
}
_, err = writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Error("failed to write from remote end")
}
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
return
}
closingFrame := &Frame{
1,
dataFrame.Seq + 1,
closingStream,
testPayload,
}
i, err = sesh.obfuscate(closingFrame, obfsBuf, 0)
if err != nil {
t.Errorf("failed to obfuscate frame %v", err)
}
_, err = writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Errorf("failed to write from remote end %v", err)
}
closingFrameDup := &Frame{
1,
dataFrame.Seq + 2,
closingStream,
testPayload,
}
i, err = sesh.obfuscate(closingFrameDup, obfsBuf, 0)
if err != nil {
t.Errorf("failed to obfuscate frame %v", err)
}
_, err = writingEnd.Write(obfsBuf[:i])
if err != nil {
t.Errorf("failed to write from remote end %v", err)
}
readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf)
if err != nil {
t.Errorf("can't read residual data %v", err)
}
assert.Eventually(t, func() bool {
sesh.streamsM.Lock()
s, _ := sesh.streams[stream.(*Stream).id]
sesh.streamsM.Unlock()
return s == nil
}, time.Second, 10*time.Millisecond, "streams still exists")
})
}
func TestStream_Read(t *testing.T) {
seshes := map[string]bool{
"ordered": false,
"unordered": true,
}
testPayload := []byte{42, 42, 42}
const smallPayloadLen = 3
f := &Frame{
1,
0,
0,
testPayload,
}
var streamID uint32
for name, unordered := range seshes {
sesh := setupSesh(unordered, emptyKey, EncryptionMethodPlain)
rawConn, rawWritingEnd := connutil.AsyncPipe()
sesh.AddConnection(common.NewTLSConn(rawConn))
writingEnd := common.NewTLSConn(rawWritingEnd)
t.Run(name, func(t *testing.T) {
buf := make([]byte, 10)
obfsBuf := make([]byte, 512)
t.Run("Plain read", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++
writingEnd.Write(obfsBuf[:i])
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
return
}
i, err = stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
return
}
if i != smallPayloadLen {
t.Errorf("expected read %v, got %v", smallPayloadLen, i)
return
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
return
}
})
t.Run("Nil buf", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++
writingEnd.Write(obfsBuf[:i])
stream, _ := sesh.Accept()
i, err := stream.Read(nil)
if i != 0 || err != nil {
t.Error("expecting", 0, nil,
"got", i, err)
}
})
t.Run("Read after stream close", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++
writingEnd.Write(obfsBuf[:i])
stream, _ := sesh.Accept()
time.Sleep(500 * time.Millisecond)
stream.Close()
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
if err != nil {
t.Errorf("cannot read residual data: %v", err)
}
if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
t.Error("expected", testPayload,
"got", buf[:smallPayloadLen])
}
_, err = stream.Read(buf)
if err == nil {
t.Error("expecting error", ErrBrokenStream,
"got nil error")
}
})
t.Run("Read after session close", func(t *testing.T) {
f.StreamID = streamID
i, _ := sesh.obfuscate(f, obfsBuf, 0)
streamID++
writingEnd.Write(obfsBuf[:i])
stream, _ := sesh.Accept()
time.Sleep(500 * time.Millisecond)
sesh.Close()
_, err := io.ReadFull(stream, buf[:smallPayloadLen])
if err != nil {
t.Errorf("cannot read resiual data: %v", err)
}
if !bytes.Equal(buf[:smallPayloadLen], testPayload) {
t.Error("expected", testPayload,
"got", buf[:smallPayloadLen])
}
_, err = stream.Read(buf)
if err == nil {
t.Error("expecting error", ErrBrokenStream,
"got nil error")
}
})
})
}
}
func TestStream_SetReadFromTimeout(t *testing.T) {
seshes := map[string]*Session{
"ordered": setupSesh(false, emptyKey, EncryptionMethodPlain),
"unordered": setupSesh(true, emptyKey, EncryptionMethodPlain),
}
for name, sesh := range seshes {
t.Run(name, func(t *testing.T) {
stream, _ := sesh.OpenStream()
stream.SetReadFromTimeout(100 * time.Millisecond)
done := make(chan struct{})
go func() {
stream.ReadFrom(connutil.Discard())
done <- struct{}{}
}()
select {
case <-done:
return
case <-time.After(500 * time.Millisecond):
t.Error("didn't timeout")
}
})
}
}

View File

@ -2,165 +2,189 @@ package multiplex
import (
"errors"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
"math/rand/v2"
"log"
"net"
"sync"
"sync/atomic"
)
type switchboardStrategy int
const (
fixedConnMapping switchboardStrategy = iota
uniformSpread
)
// switchboard represents the connection pool. It is responsible for managing
// transport-layer connections between client and server.
// It has several purposes: constantly receiving incoming data from all connections
// and pass them to Session.recvDataFromRemote(); accepting data through
// switchboard.send(), in which it selects a connection according to its
// switchboardStrategy and send the data off using that; and counting, as well as
// rate limiting, data received and sent through its Valve.
// switchboard is responsible for keeping the reference of TLS connections between client and server
type switchboard struct {
session *Session
valve Valve
strategy switchboardStrategy
*Valve
conns sync.Map
connsCount uint32
randPool sync.Pool
// optimum is the connEnclave with the smallest sendQueue
optimum atomic.Value // *connEnclave
cesM sync.RWMutex
ces []*connEnclave
broken uint32
/*
//debug
hM sync.Mutex
used map[uint32]bool
*/
}
func makeSwitchboard(sesh *Session) *switchboard {
func (sb *switchboard) getOptimum() *connEnclave {
if i := sb.optimum.Load(); i == nil {
return nil
} else {
return i.(*connEnclave)
}
}
func (sb *switchboard) setOptimum(ce *connEnclave) {
sb.optimum.Store(ce)
}
// Some data comes from a Stream to be sent through one of the many
// remoteConn, but which remoteConn should we use to send the data?
//
// In this case, we pick the remoteConn that has about the smallest sendQueue.
type connEnclave struct {
remoteConn net.Conn
sendQueue uint32
}
func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
// rates are uint64 because in the usermanager we want the bandwidth to be atomically
// operated (so that the bandwidth can change on the fly).
sb := &switchboard{
session: sesh,
strategy: uniformSpread,
valve: sesh.Valve,
randPool: sync.Pool{New: func() interface{} {
var state [32]byte
common.CryptoRandRead(state[:])
return rand.New(rand.NewChaCha8(state))
}},
Valve: valve,
ces: []*connEnclave{},
//debug
// used: make(map[uint32]bool),
}
return sb
}
var errBrokenSwitchboard = errors.New("the switchboard is broken")
var errNilOptimum error = errors.New("The optimal connection is nil")
func (sb *switchboard) addConn(conn net.Conn) {
connId := atomic.AddUint32(&sb.connsCount, 1) - 1
sb.conns.Store(connId, conn)
go sb.deplex(conn)
var ErrNoRxCredit error = errors.New("No Rx credit is left")
var ErrNoTxCredit error = errors.New("No Tx credit is left")
func (sb *switchboard) send(data []byte) (int, error) {
ce := sb.getOptimum()
if ce == nil {
return 0, errNilOptimum
}
// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable
func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) {
sb.valve.txWait(len(data))
if atomic.LoadUint32(&sb.broken) == 1 {
return 0, errBrokenSwitchboard
}
var conn net.Conn
switch sb.strategy {
case uniformSpread:
conn, err = sb.pickRandConn()
atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
go sb.updateOptimum()
n, err := ce.remoteConn.Write(data)
if err != nil {
return 0, errBrokenSwitchboard
}
n, err = conn.Write(data)
if err != nil {
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
case fixedConnMapping:
// FIXME: this strategy has a tendency to cause a TLS conn socket buffer to fill up,
// which is a problem when multiple streams are mapped to the same conn, resulting
// in all such streams being blocked.
conn = *assignedConn
if conn == nil {
conn, err = sb.pickRandConn()
if err != nil {
sb.session.SetTerminalMsg("failed to pick a connection " + err.Error())
sb.session.passiveClose()
return 0, err
sb.txWait(n)
if sb.AddTxCredit(-int64(n)) < 0 {
log.Println(ErrNoTxCredit)
defer sb.session.Close()
return n, ErrNoTxCredit
}
*assignedConn = conn
}
n, err = conn.Write(data)
if err != nil {
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
default:
return 0, errors.New("unsupported traffic distribution strategy")
}
sb.valve.AddTx(int64(n))
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
go sb.updateOptimum()
return n, nil
}
// returns a random conn. This function can be called concurrently.
func (sb *switchboard) pickRandConn() (net.Conn, error) {
if atomic.LoadUint32(&sb.broken) == 1 {
return nil, errBrokenSwitchboard
func (sb *switchboard) updateOptimum() {
currentOpti := sb.getOptimum()
currentOptiQ := atomic.LoadUint32(&currentOpti.sendQueue)
sb.cesM.RLock()
for _, ce := range sb.ces {
ceQ := atomic.LoadUint32(&ce.sendQueue)
if ceQ < currentOptiQ {
currentOpti = ce
currentOptiQ = ceQ
}
}
sb.cesM.RUnlock()
sb.setOptimum(currentOpti)
}
connsCount := atomic.LoadUint32(&sb.connsCount)
if connsCount == 0 {
return nil, errBrokenSwitchboard
func (sb *switchboard) addConn(conn net.Conn) {
newCe := &connEnclave{
remoteConn: conn,
sendQueue: 0,
}
sb.cesM.Lock()
sb.ces = append(sb.ces, newCe)
sb.cesM.Unlock()
sb.setOptimum(newCe)
go sb.deplex(newCe)
}
randReader := sb.randPool.Get().(*rand.Rand)
connId := randReader.Uint32N(connsCount)
sb.randPool.Put(randReader)
ret, ok := sb.conns.Load(connId)
if !ok {
log.Errorf("failed to get conn %d", connId)
return nil, errBrokenSwitchboard
func (sb *switchboard) removeConn(closing *connEnclave) {
sb.cesM.Lock()
for i, ce := range sb.ces {
if closing == ce {
sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
break
}
return ret.(net.Conn), nil
}
if len(sb.ces) == 0 {
sb.session.Close()
}
sb.cesM.Unlock()
}
// actively triggered by session.Close()
func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return
func (sb *switchboard) shutdown() {
for _, ce := range sb.ces {
ce.remoteConn.Close()
}
atomic.StoreUint32(&sb.connsCount, 0)
sb.conns.Range(func(_, conn interface{}) bool {
conn.(net.Conn).Close()
sb.conns.Delete(conn)
return true
})
}
// deplex function costantly reads from a TCP connection
func (sb *switchboard) deplex(conn net.Conn) {
defer conn.Close()
buf := make([]byte, sb.session.connReceiveBufferSize)
// deplex function costantly reads from a TCP connection, call deobfs and distribute it
// to the corresponding frame
func (sb *switchboard) deplex(ce *connEnclave) {
buf := make([]byte, 20480)
for {
n, err := conn.Read(buf)
sb.valve.rxWait(n)
sb.valve.AddRx(int64(n))
n, err := sb.session.obfsedRead(ce.remoteConn, buf)
sb.rxWait(n)
if err != nil {
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
sb.session.passiveClose()
log.Println(err)
go ce.remoteConn.Close()
sb.removeConn(ce)
return
}
err = sb.session.recvDataFromRemote(buf[:n])
if sb.AddRxCredit(-int64(n)) < 0 {
log.Println(ErrNoRxCredit)
sb.session.Close()
return
}
frame, err := sb.session.deobfs(buf[:n])
if err != nil {
log.Error(err)
}
log.Println(err)
continue
}
// FIXME: there has been a bug in which a packet has
// a seemingly corrupted StreamID (e.g. when the largest streamID is something like 3000
// and suddently a streamID of 3358661675 is added.
// It happens once ~6 hours and the occourance is realy unstable
// I couldn't find a way to reproduce it. But I do have some clue.
// I commented out the util.genXorKeys function so that the stream headers are being
// sent in plaintext, and this bug didn't happen again. So I suspect it has to do
// with xxHash. Either it's to do with my usage of the libary or the implementation
// of the library. Maybe there's a race somewhere? I may eventually use another
// method to encrypt the headers. xxHash isn't cryptographic afterall.
stream := sb.session.getOrAddStream(frame.StreamID, frame.Closing == 1)
// if the frame is telling us to close a closed stream
// (this happens when ss-server and ss-local closes the stream
// simutaneously), we don't do anything
if stream != nil {
stream.writeNewFrame(frame)
}
//debug
/*
sb.hM.Lock()
if sb.used[frame.StreamID] {
log.Printf("%v lost!\n", frame.StreamID)
}
sb.used[frame.StreamID] = true
sb.hM.Unlock()
*/
}
}

View File

@ -1,187 +0,0 @@
package multiplex
import (
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
)
func TestSwitchboard_Send(t *testing.T) {
doTest := func(seshConfig SessionConfig) {
sesh := MakeSession(0, seshConfig)
hole0 := connutil.Discard()
sesh.sb.addConn(hole0)
conn, err := sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
data := make([]byte, 1000)
rand.Read(data)
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
hole1 := connutil.Discard()
sesh.sb.addConn(hole1)
conn, err = sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
conn, err = sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
}
t.Run("Ordered", func(t *testing.T) {
seshConfig := SessionConfig{
Unordered: false,
}
doTest(seshConfig)
})
t.Run("Unordered", func(t *testing.T) {
seshConfig := SessionConfig{
Unordered: true,
}
doTest(seshConfig)
})
}
func BenchmarkSwitchboard_Send(b *testing.B) {
hole := connutil.Discard()
seshConfig := SessionConfig{}
sesh := MakeSession(0, seshConfig)
sesh.sb.addConn(hole)
conn, err := sesh.sb.pickRandConn()
if err != nil {
b.Error("failed to get a random conn", err)
return
}
data := make([]byte, 1000)
rand.Read(data)
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.sb.send(data, &conn)
}
}
func TestSwitchboard_TxCredit(t *testing.T) {
seshConfig := SessionConfig{
Valve: MakeValve(1<<20, 1<<20),
}
sesh := MakeSession(0, seshConfig)
hole := connutil.Discard()
sesh.sb.addConn(hole)
conn, err := sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
data := make([]byte, 1000)
rand.Read(data)
t.Run("fixed conn mapping", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0
sesh.sb.strategy = fixedConnMapping
n, err := sesh.sb.send(data[:10], &conn)
if err != nil {
t.Error(err)
return
}
if n != 10 {
t.Errorf("wanted to send %v, got %v", 10, n)
return
}
if *sesh.sb.valve.(*LimitedValve).tx != 10 {
t.Error("tx credit didn't increase by 10")
}
})
t.Run("uniform spread", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0
sesh.sb.strategy = uniformSpread
n, err := sesh.sb.send(data[:10], &conn)
if err != nil {
t.Error(err)
return
}
if n != 10 {
t.Errorf("wanted to send %v, got %v", 10, n)
return
}
if *sesh.sb.valve.(*LimitedValve).tx != 10 {
t.Error("tx credit didn't increase by 10")
}
})
}
func TestSwitchboard_CloseOnOneDisconn(t *testing.T) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
sesh := setupSesh(false, sessionKey, EncryptionMethodPlain)
conn0client, conn0server := connutil.AsyncPipe()
sesh.AddConnection(conn0client)
conn1client, _ := connutil.AsyncPipe()
sesh.AddConnection(conn1client)
conn0server.Close()
assert.Eventually(t, func() bool {
return sesh.IsClosed()
}, time.Second, 10*time.Millisecond, "session not closed after one conn is disconnected")
if _, err := conn1client.Write([]byte{0x00}); err == nil {
t.Error("the other conn is still connected")
return
}
}
func TestSwitchboard_ConnsCount(t *testing.T) {
seshConfig := SessionConfig{
Valve: MakeValve(1<<20, 1<<20),
}
sesh := MakeSession(0, seshConfig)
var wg sync.WaitGroup
for i := 0; i < 1000; i++ {
wg.Add(1)
go func() {
sesh.AddConnection(connutil.Discard())
wg.Done()
}()
}
wg.Wait()
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
t.Error("connsCount incorrect")
}
sesh.sb.closeAll()
assert.Eventuallyf(t, func() bool {
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
}

View File

@ -1,101 +1,163 @@
package server
import (
"crypto"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/util"
)
const appDataMaxLength = 16401
// ClientHello contains every field in a ClientHello message
type ClientHello struct {
handshakeType byte
length int
clientVersion []byte
random []byte
sessionIdLen int
sessionId []byte
cipherSuitesLen int
cipherSuites []byte
compressionMethodsLen int
compressionMethods []byte
extensionsLen int
extensions map[[2]byte][]byte
}
type TLS struct{}
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed Extensions")
}
}()
pointer := 0
totalLen := len(input)
ret = make(map[[2]byte][]byte)
for pointer < totalLen {
var typ [2]byte
copy(typ[:], input[pointer:pointer+2])
pointer += 2
length := util.BtoInt(input[pointer : pointer+2])
pointer += 2
data := input[pointer : pointer+length]
pointer += length
ret[typ] = data
}
return ret, err
}
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
// AddRecordLayer adds record layer to data
func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
func (TLS) String() string { return "TLS" }
// PeelRecordLayer peels off the record layer
func PeelRecordLayer(data []byte) []byte {
ret := data[5:]
return ret
}
func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) {
ch, err := parseClientHello(clientHello)
if err != nil {
log.Debug(err)
err = ErrBadClientHello
// ParseClientHello parses everything on top of the TLS layer
// (including the record layer) into ClientHello type
func ParseClientHello(data []byte) (ret *ClientHello, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed ClientHello")
}
}()
data = PeelRecordLayer(data)
pointer := 0
// Handshake Type
handshakeType := data[pointer]
if handshakeType != 0x01 {
return ret, errors.New("Not a ClientHello")
}
pointer += 1
// Length
length := util.BtoInt(data[pointer : pointer+3])
pointer += 3
if length != len(data[pointer:]) {
return ret, errors.New("Hello length doesn't match")
}
// Client Version
clientVersion := data[pointer : pointer+2]
pointer += 2
// Random
random := data[pointer : pointer+32]
pointer += 32
// Session ID
sessionIdLen := int(data[pointer])
pointer += 1
sessionId := data[pointer : pointer+sessionIdLen]
pointer += sessionIdLen
// Cipher Suites
cipherSuitesLen := util.BtoInt(data[pointer : pointer+2])
pointer += 2
cipherSuites := data[pointer : pointer+cipherSuitesLen]
pointer += cipherSuitesLen
// Compression Methods
compressionMethodsLen := int(data[pointer])
pointer += 1
compressionMethods := data[pointer : pointer+compressionMethodsLen]
pointer += compressionMethodsLen
// Extensions
extensionsLen := util.BtoInt(data[pointer : pointer+2])
pointer += 2
extensions, err := parseExtensions(data[pointer:])
ret = &ClientHello{
handshakeType,
length,
clientVersion,
random,
sessionIdLen,
sessionId,
cipherSuitesLen,
cipherSuites,
compressionMethodsLen,
compressionMethods,
extensionsLen,
extensions,
}
return
}
fragments, err = TLS{}.unmarshalClientHello(ch, privateKey)
if err != nil {
err = fmt.Errorf("failed to unmarshal ClientHello into authFragments: %v", err)
return
func composeServerHello(ch *ClientHello) []byte {
var serverHello [10][]byte
serverHello[0] = []byte{0x02} // handshake type
serverHello[1] = []byte{0x00, 0x00, 0x4d} // length 77
serverHello[2] = []byte{0x03, 0x03} // server version
serverHello[3] = util.PsudoRandBytes(32, time.Now().UnixNano()) // random
serverHello[4] = []byte{0x20} // session id length 32
serverHello[5] = ch.sessionId // session id
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
serverHello[7] = []byte{0x00} // compression method null
serverHello[8] = []byte{0x00, 0x05} // extensions length 5
serverHello[9] = []byte{0xff, 0x01, 0x00, 0x01, 0x00} // extensions renegotiation_info
ret := []byte{}
for i := 0; i < 10; i++ {
ret = append(ret, serverHello[i]...)
}
return ret
}
respond = TLS{}.makeResponder(ch.sessionId, fragments.sharedSecret)
return
}
func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret [32]byte) Responder {
respond := func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error) {
// the cert length needs to be the same for all handshakes belonging to the same session
// we can use sessionKey as a seed here to ensure consistency
possibleCertLengths := []int{42, 27, 68, 59, 36, 44, 46}
cert := make([]byte, possibleCertLengths[common.RandInt(len(possibleCertLengths))])
common.RandRead(randSource, cert)
var nonce [12]byte
common.RandRead(randSource, nonce[:])
encryptedSessionKey, err := common.AESGCMEncrypt(nonce[:], sharedSecret[:], sessionKey[:])
if err != nil {
return
}
var encryptedSessionKeyArr [48]byte
copy(encryptedSessionKeyArr[:], encryptedSessionKey)
reply := composeReply(clientHelloSessionId, nonce, encryptedSessionKeyArr, cert)
_, err = originalConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write TLS reply: %v", err)
originalConn.Close()
return
}
preparedConn = common.NewTLSConn(originalConn)
return
}
return respond
}
func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fragments authFragments, err error) {
copy(fragments.randPubKey[:], ch.random)
ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:])
if !ok {
err = ErrInvalidPubKey
return
}
var sharedSecret []byte
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return
}
copy(fragments.sharedSecret[:], sharedSecret)
var keyShare []byte
keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}])
if err != nil {
return
}
ctxTag := append(ch.sessionId, keyShare...)
if len(ctxTag) != 64 {
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(ctxTag))
return
}
copy(fragments.ciphertextWithTag[:], ctxTag)
return
// ComposeReply composes the ServerHello, ChangeCipherSpec and Finished messages
// together with their respective record layers into one byte slice. The content
// of these messages are random and useless for this plugin
func ComposeReply(ch *ClientHello) []byte {
TLS12 := []byte{0x03, 0x03}
shBytes := AddRecordLayer(composeServerHello(ch), []byte{0x16}, TLS12)
ccsBytes := AddRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
finished := make([]byte, 64)
finished = util.PsudoRandBytes(40, time.Now().UnixNano())
fBytes := AddRecordLayer(finished, []byte{0x16}, TLS12)
ret := append(shBytes, ccsBytes...)
ret = append(ret, fBytes...)
return ret
}

View File

@ -1,202 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/cbeuw/Cloak/internal/common"
)
// ClientHello contains every field in a ClientHello message
type ClientHello struct {
handshakeType byte
length int
clientVersion []byte
random []byte
sessionIdLen int
sessionId []byte
cipherSuitesLen int
cipherSuites []byte
compressionMethodsLen int
compressionMethods []byte
extensionsLen int
extensions map[[2]byte][]byte
}
var u16 = binary.BigEndian.Uint16
var u32 = binary.BigEndian.Uint32
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed Extensions")
}
}()
pointer := 0
totalLen := len(input)
ret = make(map[[2]byte][]byte)
for pointer < totalLen {
var typ [2]byte
copy(typ[:], input[pointer:pointer+2])
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
data := input[pointer : pointer+length]
pointer += length
ret[typ] = data
}
return ret, err
}
func parseKeyShare(input []byte) (ret []byte, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("malformed key_share")
}
}()
totalLen := int(u16(input[0:2]))
// 2 bytes "client key share length"
pointer := 2
for pointer < totalLen {
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
// skip "key exchange length"
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
if length != 32 {
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
}
return input[pointer : pointer+length], nil
}
pointer += 2
length := int(u16(input[pointer : pointer+2]))
pointer += 2
_ = input[pointer : pointer+length]
pointer += length
}
return nil, errors.New("x25519 does not exist")
}
// addRecordLayer adds record layer to data
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
// parseClientHello parses everything on top of the TLS layer
// (including the record layer) into ClientHello type
func parseClientHello(data []byte) (ret *ClientHello, err error) {
defer func() {
if r := recover(); r != nil {
err = errors.New("Malformed ClientHello")
}
}()
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
}
peeled := make([]byte, len(data)-5)
copy(peeled, data[5:])
pointer := 0
// Handshake Type
handshakeType := peeled[pointer]
if handshakeType != 0x01 {
return ret, errors.New("Not a ClientHello")
}
pointer += 1
// Length
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
pointer += 3
if length != len(peeled[pointer:]) {
return ret, errors.New("Hello length doesn't match")
}
// Client Version
clientVersion := peeled[pointer : pointer+2]
pointer += 2
// Random
random := peeled[pointer : pointer+32]
pointer += 32
// Session ID
sessionIdLen := int(peeled[pointer])
pointer += 1
sessionId := peeled[pointer : pointer+sessionIdLen]
pointer += sessionIdLen
// Cipher Suites
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
pointer += cipherSuitesLen
// Compression Methods
compressionMethodsLen := int(peeled[pointer])
pointer += 1
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
pointer += compressionMethodsLen
// Extensions
extensionsLen := int(u16(peeled[pointer : pointer+2]))
pointer += 2
extensions, err := parseExtensions(peeled[pointer:])
ret = &ClientHello{
handshakeType,
length,
clientVersion,
random,
sessionIdLen,
sessionId,
cipherSuitesLen,
cipherSuites,
compressionMethodsLen,
compressionMethods,
extensionsLen,
extensions,
}
return
}
func composeServerHello(sessionId []byte, nonce [12]byte, encryptedSessionKeyWithTag [48]byte) []byte {
var serverHello [11][]byte
serverHello[0] = []byte{0x02} // handshake type
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 118
serverHello[2] = []byte{0x03, 0x03} // server version
serverHello[3] = append(nonce[0:12], encryptedSessionKeyWithTag[0:20]...) // random 32 bytes
serverHello[4] = []byte{0x20} // session id length 32
serverHello[5] = sessionId // session id
serverHello[6] = []byte{0x13, 0x02} // cipher suite TLS_AES_256_GCM_SHA384
serverHello[7] = []byte{0x00} // compression method null
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
keyShare := []byte{0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20}
keyExchange := make([]byte, 32)
copy(keyExchange, encryptedSessionKeyWithTag[20:48])
common.CryptoRandRead(keyExchange[28:32])
serverHello[9] = append(keyShare, keyExchange...)
serverHello[10] = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} // supported versions
var ret []byte
for _, s := range serverHello {
ret = append(ret, s...)
}
return ret
}
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
// together with their respective record layers into one byte slice.
func composeReply(clientHelloSessionId []byte, nonce [12]byte, encryptedSessionKeyWithTag [48]byte, cert []byte) []byte {
TLS12 := []byte{0x03, 0x03}
sh := composeServerHello(clientHelloSessionId, nonce, encryptedSessionKeyWithTag)
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
ret := append(shBytes, ccsBytes...)
ret = append(ret, encryptedCertBytes...)
return ret
}

View File

@ -1,54 +0,0 @@
package server
import (
"bytes"
"encoding/hex"
"testing"
)
func TestParseClientHello(t *testing.T) {
t.Run("good Cloak ClientHello", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fbf21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, err := parseClientHello(chBytes)
if err != nil {
t.Errorf("Expecting no error, got %v", err)
return
}
if !bytes.Equal(ch.clientVersion, []byte{0x03, 0x03}) {
t.Errorf("expecting client version 0x0303, got %v", ch.clientVersion)
return
}
})
t.Run("Malformed ClientHello", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fb2f21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
_, err := parseClientHello(chBytes)
if err == nil {
t.Error("expecting Malformed ClientHello, got no error")
return
}
})
t.Run("not Handshake", func(t *testing.T) {
chBytes, _ := hex.DecodeString("ff03010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fbf21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
_, err := parseClientHello(chBytes)
if err == nil {
t.Error("not a tls handshake, got no error")
return
}
})
t.Run("wrong TLS record layer version", func(t *testing.T) {
chBytes, _ := hex.DecodeString("16ff010200010001fc03034986187cfaf4c55866a0d9b68f82505fd694a3f0fbf21ca3dcf260baad91d75e20c10e2d2c66f4f9366296678550ed769aa0c41cae7e5f480f59bd929b747ee48d0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00208d7d5a544a72e67adb1bacde46aa147b086f714c073f8335688dc13b2a032986001700414e06fb9a27480a93159f3d6273afebb4d307c4a734d7107d883b6edacb58f7d289a95ad8aaedef1b5f76fe09267a14e6bee2b6db4506b43cf0a410a4645105f79f002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
_, err := parseClientHello(chBytes)
if err == nil {
t.Error("wrong version, got no error")
return
}
})
t.Run("TLS 1.2", func(t *testing.T) {
chBytes, _ := hex.DecodeString("16030300bd010000b903035d5741ed86719917a932db1dc59a22c7166bf90f5bd693564341d091ffbac5db00002ac02cc02bc030c02f009f009ec024c023c028c027c00ac009c014c013009d009c003d003c0035002f000a0100006600000022002000001d6e61762e736d61727473637265656e2e6d6963726f736f66742e636f6d000500050100000000000a00080006001d00170018000b00020100000d001400120401050102010403050302030202060106030023000000170000ff01000100")
_, err := parseClientHello(chBytes)
if err == nil {
t.Error("wrong version, got no error")
return
}
})
}

View File

@ -1,79 +0,0 @@
package server
import (
"sync"
"github.com/cbeuw/Cloak/internal/server/usermanager"
mux "github.com/cbeuw/Cloak/internal/multiplex"
)
type ActiveUser struct {
panel *userPanel
arrUID [16]byte
valve mux.Valve
bypass bool
sessionsM sync.RWMutex
sessions map[uint32]*mux.Session
}
// CloseSession closes a session and removes its reference from the user
func (u *ActiveUser) CloseSession(sessionID uint32, reason string) {
u.sessionsM.Lock()
sesh, existing := u.sessions[sessionID]
if existing {
delete(u.sessions, sessionID)
sesh.SetTerminalMsg(reason)
sesh.Close()
}
remaining := len(u.sessions)
u.sessionsM.Unlock()
if remaining == 0 {
u.panel.TerminateActiveUser(u, "no session left")
}
}
// GetSession returns the reference to an existing session, or if one such session doesn't exist, it queries
// the UserManager for the authorisation for a new session. If a new session is allowed, it creates this new session
// and returns its reference
func (u *ActiveUser) GetSession(sessionID uint32, config mux.SessionConfig) (sesh *mux.Session, existing bool, err error) {
u.sessionsM.Lock()
defer u.sessionsM.Unlock()
if sesh = u.sessions[sessionID]; sesh != nil {
return sesh, true, nil
} else {
if !u.bypass {
ainfo := usermanager.AuthorisationInfo{NumExistingSessions: len(u.sessions)}
err := u.panel.Manager.AuthoriseNewSession(u.arrUID[:], ainfo)
if err != nil {
return nil, false, err
}
}
config.Valve = u.valve
sesh = mux.MakeSession(sessionID, config)
u.sessions[sessionID] = sesh
return sesh, false, nil
}
}
// closeAllSessions closes all sessions of this active user
func (u *ActiveUser) closeAllSessions(reason string) {
u.sessionsM.Lock()
for sessionID, sesh := range u.sessions {
sesh.SetTerminalMsg(reason)
sesh.Close()
delete(u.sessions, sessionID)
}
u.sessionsM.Unlock()
}
// NumSession returns the number of active sessions
func (u *ActiveUser) NumSession() int {
u.sessionsM.RLock()
defer u.sessionsM.RUnlock()
return len(u.sessions)
}

View File

@ -1,123 +0,0 @@
package server
import (
"crypto/rand"
"encoding/base64"
"io/ioutil"
"os"
"testing"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server/usermanager"
)
func getSeshConfig(unordered bool) mux.SessionConfig {
var sessionKey [32]byte
rand.Read(sessionKey[:])
obfuscator, _ := mux.MakeObfuscator(0x00, sessionKey)
seshConfig := mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: unordered,
}
return seshConfig
}
func TestActiveUser_Bypass(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
manager, err := usermanager.MakeLocalManager(tmpDB.Name(), common.RealWorldState)
if err != nil {
t.Fatal("failed to make local manager", err)
}
panel := MakeUserPanel(manager)
UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==")
user, _ := panel.GetBypassUser(UID)
var sesh0 *mux.Session
var existing bool
var sesh1 *mux.Session
// get first session
sesh0, existing, err = user.GetSession(0, getSeshConfig(false))
if err != nil {
t.Fatal(err)
}
if existing {
t.Fatal("get first session: first session returned as existing")
}
if sesh0 == nil {
t.Fatal("get first session: no session returned")
}
// get first session again
seshx, existing, err := user.GetSession(0, mux.SessionConfig{})
if err != nil {
t.Fatal(err)
}
if !existing {
t.Fatal("get first session again: first session get again returned as not existing")
}
if seshx == nil {
t.Fatal("get first session again: no session returned")
}
if seshx != sesh0 {
t.Fatal("returned a different instance")
}
// get second session
sesh1, existing, err = user.GetSession(1, getSeshConfig(false))
if err != nil {
t.Fatal(err)
}
if existing {
t.Fatal("get second session: second session returned as existing")
}
if sesh1 == nil {
t.Fatal("get second session: no session returned")
}
if user.NumSession() != 2 {
t.Fatal("number of session is not 2")
}
user.CloseSession(0, "")
if user.NumSession() != 1 {
t.Fatal("number of session is not 1 after deleting one")
}
if !sesh0.IsClosed() {
t.Fatal("session not closed after deletion")
}
user.closeAllSessions("")
if !sesh1.IsClosed() {
t.Fatal("session not closed after user termination")
}
// get session again after termination
seshy, existing, err := user.GetSession(0, getSeshConfig(false))
if err != nil {
t.Fatal(err)
}
if existing {
t.Fatal("get session again after termination: session returned as existing")
}
if seshy == nil {
t.Fatal("get session again after termination: no session returned")
}
if seshy == sesh0 || seshy == sesh1 {
t.Fatal("get session after termination returned the same instance")
}
user.CloseSession(0, "")
if panel.isActive(user.arrUID[:]) {
t.Fatal("user still active after last session deleted")
}
err = manager.Close()
if err != nil {
t.Fatal("failed to close localmanager", err)
}
}

View File

@ -2,88 +2,63 @@ package server
import (
"bytes"
"crypto"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"time"
"log"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
"github.com/cbeuw/Cloak/internal/util"
ecdh "github.com/cbeuw/go-ecdh"
)
type ClientInfo struct {
UID []byte
SessionId uint32
ProxyMethod string
EncryptionMethod byte
Unordered bool
Transport Transport
}
type authFragments struct {
sharedSecret [32]byte
randPubKey [32]byte
ciphertextWithTag [64]byte
}
const (
UNORDERED_FLAG = 0x01 // 0000 0001
)
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
// decryptClientInfo checks if a the authFragments are valid. It doesn't check if the UID is authorised
func decryptClientInfo(fragments authFragments, serverTime time.Time) (info ClientInfo, err error) {
var plaintext []byte
plaintext, err = common.AESGCMDecrypt(fragments.randPubKey[0:12], fragments.sharedSecret[:], fragments.ciphertextWithTag[:])
// input ticket, return UID
func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, uint32, error) {
ec := ecdh.NewCurve25519ECDH()
ephPub, _ := ec.Unmarshal(ticket[0:32])
key, err := ec.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return
return nil, 0, err
}
UIDsID := util.AESDecrypt(ticket[0:16], key, ticket[32:68])
sessionID := binary.BigEndian.Uint32(UIDsID[32:36])
return UIDsID[0:32], sessionID, nil
}
info = ClientInfo{
UID: plaintext[0:16],
SessionId: 0,
ProxyMethod: string(bytes.Trim(plaintext[16:28], "\x00")),
EncryptionMethod: plaintext[28],
Unordered: plaintext[41]&UNORDERED_FLAG != 0,
func validateRandom(random []byte, UID []byte, time int64) bool {
t := make([]byte, 8)
binary.BigEndian.PutUint64(t, uint64(time/(12*60*60)))
rdm := random[0:16]
preHash := make([]byte, 56)
copy(preHash[0:32], UID)
copy(preHash[32:40], t)
copy(preHash[40:56], rdm)
h := sha256.New()
h.Write(preHash)
return bytes.Equal(h.Sum(nil)[0:16], random[16:32])
}
timestamp := int64(binary.BigEndian.Uint64(plaintext[29:37]))
clientTime := time.Unix(timestamp, 0)
if !(clientTime.After(serverTime.Add(-timestampTolerance)) && clientTime.Before(serverTime.Add(timestampTolerance))) {
err = fmt.Errorf("%v: received timestamp %v", ErrTimestampOutOfWindow, timestamp)
return
func TouchStone(ch *ClientHello, sta *State) (isSS bool, UID []byte, sessionID uint32) {
var random [32]byte
copy(random[:], ch.random)
used := sta.getUsedRandom(random)
if used != 0 {
log.Println("Replay! Duplicate random")
return false, nil, 0
}
info.SessionId = binary.BigEndian.Uint32(plaintext[37:41])
return
sta.putUsedRandom(random)
ticket := ch.extensions[[2]byte{0x00, 0x23}]
if len(ticket) < 64 {
return false, nil, 0
}
var ErrReplay = errors.New("duplicate random")
var ErrBadProxyMethod = errors.New("invalid proxy method")
var ErrBadDecryption = errors.New("decryption/authentication failure")
// AuthFirstPacket checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client
// if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
// the handshake
func AuthFirstPacket(firstPacket []byte, transport Transport, sta *State) (info ClientInfo, finisher Responder, err error) {
fragments, finisher, err := transport.processFirstPacket(firstPacket, sta.StaticPv)
UID, sessionID, err := decryptSessionTicket(sta.staticPv, ticket)
if err != nil {
return
log.Printf("ts: %v\n", err)
return false, nil, 0
}
isSS = validateRandom(ch.random, UID, sta.Now().Unix())
if !isSS {
return false, nil, 0
}
if sta.registerRandom(fragments.randPubKey) {
err = ErrReplay
return
}
info, err = decryptClientInfo(fragments, sta.WorldState.Now().UTC())
if err != nil {
log.Debug(err)
err = fmt.Errorf("%w: %v", ErrBadDecryption, err)
return
}
info.Transport = transport
return
}

View File

@ -1,196 +0,0 @@
package server
import (
"crypto"
"encoding/hex"
"fmt"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh"
)
func TestDecryptClientInfo(t *testing.T) {
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
p, _ := ecdh.Unmarshal(pvBytes)
staticPv := p.(crypto.PrivateKey)
t.Run("correct time", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSix := time.Unix(1565998966, 0)
cinfo, err := decryptClientInfo(ai, nineSixSix)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
if cinfo.SessionId != 3710878841 {
t.Errorf("expecting session id 3710878841, got %v", cinfo.SessionId)
}
})
t.Run("roughly correct time", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSixP50 := time.Unix(1565998966, 0).Add(50)
_, err = decryptClientInfo(ai, nineSixSixP50)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSixM50 := time.Unix(1565998966, 0).Add(-50)
_, err = decryptClientInfo(ai, nineSixSixM50)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
})
t.Run("over interval", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSixOver := time.Unix(1565998966, 0).Add(timestampTolerance + 10)
_, err = decryptClientInfo(ai, nineSixSixOver)
if err == nil {
t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err)
return
}
})
t.Run("under interval", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
nineSixSixUnder := time.Unix(1565998966, 0).Add(-(timestampTolerance + 10))
_, err = decryptClientInfo(ai, nineSixSixUnder)
if err == nil {
t.Errorf("expecting %v, got %v", ErrTimestampOutOfWindow, err)
return
}
})
t.Run("not cloak psk", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010246010002420303794ae79c6db7a31e67e2ce91b8afcb82995ae79ad1d0dc885f933e4193bf95cd208abd7a70f3b82cc31c02f1c2b94ba74d5222a66695a5cf92a366421d7f5eb9530022fafa130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001d75a5a00000000001e001c0000196c68332e676f6f676c6575736572636f6e74656e742e636f6d00170000ff01000100000a000a0008baba001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029baba000100001d002074bfe93336c364b43cf0879d997b2e11dc97068b86fc90174e0f2bcea1d4ed1c002d00020101002b000b0ababa0304030303020301001b00030200029a9a0001000029010500e000da00d1f6c0918f865390ae3ca33c77f61a1974cb4533456071b214ec018d17dc22845f2f72cf1dba48f9cdc0758803002dda9b964fad5522e82442af7cbbe242241e39233386f2383bce3ced8e16b1ae3f0ef52a706f58e1e6a1bca0cd3b3a2a4c4cb738770b01b56bf3e73c472bf4fb238cab510aa78f8427a3ca99f741aa433f548be460705f43a3abe878cec6ee3158c129406910b93e798e8a7aaffc2e7ff7b8fd872778d3687a0beaa1452fe7ec418070d537344b64d09f6edd053346ff9c9678eef6b8886882aba81d4be11d9df653de35659f93a22ac39399e3ba400021204e22b73261693967a9216fe4a3b004571c53f316309e76671a18d78931b5b072")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
fiveOSix := time.Unix(1565999506, 0)
cinfo, err := decryptClientInfo(ai, fiveOSix)
if err == nil {
t.Errorf("not a cloak, got nil error and cinfo %v", cinfo)
return
}
})
t.Run("not cloak no psk", func(t *testing.T) {
chBytes, _ := hex.DecodeString("1603010200010001fc0303eae4c204a867390a758fcff3afa5803cac3e07011cf0c9f3befc1267445aabee20fc398df698113617f8161cbcb89534efa892088a6c5e49246534e05f790ea36f00220a0a130113021303c02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a010001910a0a000000000014001200000f63646e2e62697a69626c652e636f6d00170000ff01000100000a000a0008caca001d00170018000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d00140012040308040401050308050501080606010201001200000033002b0029caca000100001d00204c8f1563fb70c261bc0c32c1b568b8d02fab25f4094711e7868b1712751dc754002d00020101002b000b0a2a2a0304030303020301001b00030200026a6a000100001500c9000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
ch, _ := parseClientHello(chBytes)
ai, err := TLS{}.unmarshalClientHello(ch, staticPv)
if err != nil {
t.Errorf("expecting no error, got %v", err)
return
}
sixOneFive := time.Unix(1565999615, 0)
cinfo, err := decryptClientInfo(ai, sixOneFive)
if err == nil {
t.Errorf("not a cloak, got nil error and cinfo %v", cinfo)
return
}
})
}
func TestAuthFirstPacket(t *testing.T) {
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
p, _ := ecdh.Unmarshal(pvBytes)
getNewState := func() *State {
sta, _ := InitState(RawConfig{}, common.WorldOfTime(time.Unix(1565998966, 0)))
sta.StaticPv = p.(crypto.PrivateKey)
sta.ProxyBook["shadowsocks"] = nil
return sta
}
t.Run("TLS correct", func(t *testing.T) {
sta := getNewState()
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
info, _, err := AuthFirstPacket(chBytes, TLS{}, sta)
if err != nil {
t.Errorf("failed to get client info: %v", err)
return
}
if info.SessionId != 3710878841 {
t.Error("failed to get correct session id")
return
}
if info.Transport.(fmt.Stringer).String() != "TLS" {
t.Errorf("wrong transport: %v", info.Transport)
return
}
})
t.Run("TLS correct but replay", func(t *testing.T) {
sta := getNewState()
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
_, _, err := AuthFirstPacket(chBytes, TLS{}, sta)
if err != nil {
t.Error("failed to prepare for the first time")
return
}
_, _, err = AuthFirstPacket(chBytes, TLS{}, sta)
if err != ErrReplay {
t.Errorf("failed to return ErrReplay, got %v instead", err)
return
}
})
t.Run("Websocket correct", func(t *testing.T) {
sta, _ := InitState(RawConfig{}, common.WorldOfTime(time.Unix(1584358419, 0)))
sta.StaticPv = p.(crypto.PrivateKey)
sta.ProxyBook["shadowsocks"] = nil
req := `GET / HTTP/1.1
Host: d2jkinvisak5y9.cloudfront.net:443
User-Agent: Go-http-client/1.1
Connection: Upgrade
Hidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ
Sec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==
Sec-WebSocket-Version: 13
Upgrade: websocket
`
info, _, err := AuthFirstPacket([]byte(req), WebSocket{}, sta)
if err != nil {
t.Errorf("failed to get client info: %v", err)
return
}
if info.Transport.(fmt.Stringer).String() != "WebSocket" {
t.Errorf("wrong transport: %v", info.Transport)
return
}
})
}

View File

@ -1,311 +0,0 @@
package server
import (
"bytes"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/server/usermanager"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
var b64 = base64.StdEncoding.EncodeToString
const firstPacketSize = 3000
func Serve(l net.Listener, sta *State) {
waitDur := [10]time.Duration{
50 * time.Millisecond, 100 * time.Millisecond, 300 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second,
3 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second, 30 * time.Second}
fails := 0
for {
conn, err := l.Accept()
if err != nil {
log.Errorf("%v, retrying", err)
time.Sleep(waitDur[fails])
if fails < 9 {
fails++
}
continue
}
fails = 0
go dispatchConnection(conn, sta)
}
}
func connReadLine(conn net.Conn, buf []byte) (int, error) {
i := 0
for ; i < len(buf); i++ {
_, err := io.ReadFull(conn, buf[i:i+1])
if err != nil {
return i, err
}
if buf[i] == '\n' {
return i + 1, nil
}
}
return i, io.ErrShortBuffer
}
var ErrUnrecognisedProtocol = errors.New("unrecognised protocol")
func readFirstPacket(conn net.Conn, buf []byte, timeout time.Duration) (int, Transport, bool, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
defer conn.SetReadDeadline(time.Time{})
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
err = fmt.Errorf("read error after connection is established: %v", err)
conn.Close()
return 0, nil, false, err
}
// TODO: give the option to match the protocol with port
bufOffset := 1
var transport Transport
switch buf[0] {
case 0x16:
transport = TLS{}
recordLayerLength := 5
i, err := io.ReadFull(conn, buf[bufOffset:recordLayerLength])
bufOffset += i
if err != nil {
err = fmt.Errorf("read error after connection is established: %v", err)
conn.Close()
return bufOffset, transport, false, err
}
dataLength := int(binary.BigEndian.Uint16(buf[3:5]))
if dataLength+recordLayerLength > len(buf) {
return bufOffset, transport, true, io.ErrShortBuffer
}
i, err = io.ReadFull(conn, buf[recordLayerLength:dataLength+recordLayerLength])
bufOffset += i
if err != nil {
err = fmt.Errorf("read error after connection is established: %v", err)
conn.Close()
return bufOffset, transport, false, err
}
case 0x47:
transport = WebSocket{}
for {
i, err := connReadLine(conn, buf[bufOffset:])
line := buf[bufOffset : bufOffset+i]
bufOffset += i
if err != nil {
if err == io.ErrShortBuffer {
return bufOffset, transport, true, err
} else {
err = fmt.Errorf("error reading first packet: %v", err)
conn.Close()
return bufOffset, transport, false, err
}
}
if bytes.Equal(line, []byte("\r\n")) {
break
}
}
default:
return bufOffset, transport, true, ErrUnrecognisedProtocol
}
return bufOffset, transport, true, nil
}
func dispatchConnection(conn net.Conn, sta *State) {
var err error
buf := make([]byte, firstPacketSize)
i, transport, redirOnErr, err := readFirstPacket(conn, buf, 15*time.Second)
data := buf[:i]
goWeb := func() {
redirPort := sta.RedirPort
if redirPort == "" {
_, redirPort, _ = net.SplitHostPort(conn.LocalAddr().String())
}
webConn, err := sta.RedirDialer.Dial("tcp", net.JoinHostPort(sta.RedirHost.String(), redirPort))
if err != nil {
log.Errorf("Making connection to redirection server: %v", err)
return
}
_, err = webConn.Write(data)
if err != nil {
log.Error("Failed to send first packet to redirection server", err)
return
}
go common.Copy(webConn, conn)
go common.Copy(conn, webConn)
}
if err != nil {
log.WithField("remoteAddr", conn.RemoteAddr()).
Warnf("error reading first packet: %v", err)
if redirOnErr {
goWeb()
} else {
conn.Close()
}
return
}
ci, finishHandshake, err := AuthFirstPacket(data, transport, sta)
if err != nil {
log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Warn(err)
goWeb()
return
}
var sessionKey [32]byte
common.RandRead(sta.WorldState.Rand, sessionKey[:])
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
if err != nil {
log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Error(err)
goWeb()
return
}
seshConfig := mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
Unordered: ci.Unordered,
MsgOnWireSizeLimit: appDataMaxLength,
}
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
// added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode
if len(sta.AdminUID) != 0 && bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 {
sesh := mux.MakeSession(0, seshConfig)
preparedConn, err := finishHandshake(conn, sessionKey, sta.WorldState.Rand)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
sesh.AddConnection(preparedConn)
//TODO: Router could be nil in cnc mode
log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session")
err = http.Serve(sesh, usermanager.APIRouterOf(sta.Panel.Manager))
// http.Serve never returns with non-nil error
log.Error(err)
return
}
if _, ok := sta.ProxyBook[ci.ProxyMethod]; !ok {
log.WithFields(log.Fields{
"remoteAddr": conn.RemoteAddr(),
"UID": b64(ci.UID),
"sessionId": ci.SessionId,
"proxyMethod": ci.ProxyMethod,
"encryptionMethod": ci.EncryptionMethod,
}).Error(ErrBadProxyMethod)
goWeb()
return
}
var user *ActiveUser
if sta.IsBypass(ci.UID) {
user, err = sta.Panel.GetBypassUser(ci.UID)
} else {
user, err = sta.Panel.GetUser(ci.UID)
}
if err != nil {
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"remoteAddr": conn.RemoteAddr(),
"error": err,
}).Warn("+1 unauthorised UID")
goWeb()
return
}
sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
if err != nil {
user.CloseSession(ci.SessionId, "")
log.Error(err)
return
}
preparedConn, err := finishHandshake(conn, sesh.GetSessionKey(), sta.WorldState.Rand)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
sesh.AddConnection(preparedConn)
if !existing {
// if the session was newly made, we serve connections from the session streams to the proxy server
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
}).Info("New session")
serveSession(sesh, ci, user, sta)
}
}
func serveSession(sesh *mux.Session, ci ClientInfo, user *ActiveUser, sta *State) error {
for {
newStream, err := sesh.Accept()
if err != nil {
if err == mux.ErrBrokenSession {
log.WithFields(log.Fields{
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
"reason": sesh.TerminalMsg(),
}).Info("Session closed")
user.CloseSession(ci.SessionId, "")
return nil
} else {
log.Errorf("unhandled error on session.Accept(): %v", err)
continue
}
}
proxyAddr := sta.ProxyBook[ci.ProxyMethod]
localConn, err := sta.ProxyDialer.Dial(proxyAddr.Network(), proxyAddr.String())
if err != nil {
log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err)
user.CloseSession(ci.SessionId, "Failed to connect to proxy server")
return err
}
log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod)
go func() {
if _, err := common.Copy(localConn, newStream); err != nil {
log.Tracef("copying stream to proxy server: %v", err)
}
}()
go func() {
if _, err := common.Copy(newStream, localConn); err != nil {
log.Tracef("copying proxy server to stream: %v", err)
}
}()
}
}

View File

@ -1,211 +0,0 @@
package server
import (
"encoding/hex"
"io"
"net"
"testing"
"time"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
)
type rfpReturnValue struct {
n int
transport Transport
redirOnErr bool
err error
}
const timeout = 500 * time.Millisecond
func TestReadFirstPacket(t *testing.T) {
rfp := func(conn net.Conn, buf []byte, retChan chan<- rfpReturnValue) {
ret := rfpReturnValue{}
ret.n, ret.transport, ret.redirOnErr, ret.err = readFirstPacket(conn, buf, timeout)
retChan <- ret
}
t.Run("Good TLS", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
local.Write(first)
ret := <-retChan
assert.Equal(t, len(first), ret.n)
assert.Equal(t, first, buf[:ret.n])
assert.IsType(t, TLS{}, ret.transport)
assert.NoError(t, ret.err)
})
t.Run("Good TLS but buf too small", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 10)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
local.Write(first)
ret := <-retChan
assert.Equal(t, io.ErrShortBuffer, ret.err)
assert.True(t, ret.redirOnErr)
assert.Equal(t, first[:ret.n], buf[:ret.n])
})
t.Run("Incomplete timeout", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("160301")
local.Write(first)
select {
case ret := <-retChan:
assert.Equal(t, len(first), ret.n)
assert.False(t, ret.redirOnErr)
assert.Error(t, ret.err)
case <-time.After(2 * timeout):
assert.Fail(t, "readFirstPacket should have timed out")
}
})
t.Run("Incomplete payload timeout", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("16030101010000")
local.Write(first)
select {
case ret := <-retChan:
assert.Equal(t, len(first), ret.n)
assert.False(t, ret.redirOnErr)
assert.Error(t, ret.err)
case <-time.After(2 * timeout):
assert.Fail(t, "readFirstPacket should have timed out")
}
})
t.Run("Good TLS staggered", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
local.Write(first[:100])
time.Sleep(timeout / 2)
local.Write(first[100:])
ret := <-retChan
assert.Equal(t, len(first), ret.n)
assert.Equal(t, first, buf[:ret.n])
assert.IsType(t, TLS{}, ret.transport)
assert.NoError(t, ret.err)
})
t.Run("TLS bad recordlayer length", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
first, _ := hex.DecodeString("160301ffff")
local.Write(first)
ret := <-retChan
assert.Equal(t, len(first), ret.n)
assert.Equal(t, first, buf[:ret.n])
assert.IsType(t, TLS{}, ret.transport)
assert.Equal(t, io.ErrShortBuffer, ret.err)
assert.True(t, ret.redirOnErr)
})
t.Run("Good WebSocket", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
req := []byte(reqStr)
local.Write(req)
ret := <-retChan
assert.Equal(t, len(req), ret.n)
assert.Equal(t, req, buf[:ret.n])
assert.IsType(t, WebSocket{}, ret.transport)
assert.NoError(t, ret.err)
})
t.Run("Good WebSocket but buf too small", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 10)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
req := []byte(reqStr)
local.Write(req)
ret := <-retChan
assert.Equal(t, io.ErrShortBuffer, ret.err)
assert.True(t, ret.redirOnErr)
assert.Equal(t, req[:ret.n], buf[:ret.n])
})
t.Run("Incomplete WebSocket timeout", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET /"
req := []byte(reqStr)
local.Write(req)
select {
case ret := <-retChan:
assert.Equal(t, len(req), ret.n)
assert.False(t, ret.redirOnErr)
assert.Error(t, ret.err)
case <-time.After(2 * timeout):
assert.Fail(t, "readFirstPacket should have timed out")
}
})
t.Run("Staggered WebSocket", func(t *testing.T) {
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue)
go rfp(remote, buf, retChan)
reqStr := "GET / HTTP/1.1\r\nHost: d2jkinvisak5y9.cloudfront.net:443\r\nUser-Agent: Go-http-client/1.1\r\nConnection: Upgrade\r\nHidden: oJxeEwfDWg5k5Jbl8ttZD1sc0fHp8VjEtXHsqEoSrnaLRe/M+KGXkOzpc/2fRRg9Vk+wIWRsfv8IpoBPLbqO+ZfGsPXTjUJGiI9BqxrcJfkxncXA7FAHGpTc84tzBtZZ\r\nSec-WebSocket-Key: lJYh7X8DRXW1U0h9WKwVMA==\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\n\r\n"
req := []byte(reqStr)
local.Write(req[:100])
time.Sleep(timeout / 2)
local.Write(req[100:])
ret := <-retChan
assert.Equal(t, len(req), ret.n)
assert.Equal(t, req, buf[:ret.n])
assert.IsType(t, WebSocket{}, ret.transport)
assert.NoError(t, ret.err)
})
}

View File

@ -1,64 +0,0 @@
//go:build gofuzz
// +build gofuzz
package server
import (
"errors"
"net"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/connutil"
)
type rfpReturnValue_fuzz struct {
n int
transport Transport
redirOnErr bool
err error
}
func Fuzz(data []byte) int {
var bypassUID [16]byte
var pv [32]byte
sta := &State{
BypassUID: map[[16]byte]struct{}{
bypassUID: {},
},
ProxyBook: map[string]net.Addr{
"shadowsocks": nil,
},
UsedRandom: map[[32]byte]int64{},
StaticPv: &pv,
WorldState: common.RealWorldState,
}
rfp := func(conn net.Conn, buf []byte, retChan chan<- rfpReturnValue_fuzz) {
ret := rfpReturnValue_fuzz{}
ret.n, ret.transport, ret.redirOnErr, ret.err = readFirstPacket(conn, buf, 500*time.Millisecond)
retChan <- ret
}
local, remote := connutil.AsyncPipe()
buf := make([]byte, 1500)
retChan := make(chan rfpReturnValue_fuzz)
go rfp(remote, buf, retChan)
local.Write(data)
ret := <-retChan
if ret.err != nil {
return 1
}
_, _, err := AuthFirstPacket(buf[:ret.n], ret.transport, sta)
if !errors.Is(err, ErrReplay) && !errors.Is(err, ErrBadDecryption) {
return 1
}
return 0
}

View File

@ -2,232 +2,151 @@ package server
import (
"crypto"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"strings"
"sync"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/server/usermanager"
)
type RawConfig struct {
ProxyBook map[string][]string
BindAddr []string
BypassUID [][]byte
RedirAddr string
PrivateKey []byte
AdminUID []byte
type rawConfig struct {
WebServerAddr string
PrivateKey string
AdminUID string
DatabasePath string
KeepAlive int
CncMode bool
BackupDirPath string
}
type stateManager interface {
ParseConfig(string) error
SetAESKey(string)
PutUsedRandom([32]byte)
}
// State type stores the global state of the program
type State struct {
ProxyBook map[string]net.Addr
ProxyDialer common.Dialer
SS_LOCAL_HOST string
SS_LOCAL_PORT string
SS_REMOTE_HOST string
SS_REMOTE_PORT string
WorldState common.WorldState
Now func() time.Time
AdminUID []byte
BypassUID map[[16]byte]struct{}
StaticPv crypto.PrivateKey
// TODO: this doesn't have to be a net.Addr; resolution is done in Dial automatically
RedirHost net.Addr
RedirPort string
RedirDialer common.Dialer
staticPv crypto.PrivateKey
Userpanel *usermanager.Userpanel
usedRandomM sync.RWMutex
UsedRandom map[[32]byte]int64
usedRandom map[[32]byte]int
Panel *userPanel
WebServerAddr string
}
func parseRedirAddr(redirAddr string) (net.Addr, string, error) {
var host string
var port string
colonSep := strings.Split(redirAddr, ":")
if len(colonSep) > 1 {
if len(colonSep) == 2 {
// domain or ipv4 with port
host = colonSep[0]
port = colonSep[1]
} else {
if strings.Contains(redirAddr, "[") {
// ipv6 with port
port = colonSep[len(colonSep)-1]
host = strings.TrimSuffix(redirAddr, "]:"+port)
host = strings.TrimPrefix(host, "[")
} else {
// ipv6 without port
host = redirAddr
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) (*State, error) {
ret := &State{
SS_LOCAL_HOST: localHost,
SS_LOCAL_PORT: localPort,
SS_REMOTE_HOST: remoteHost,
SS_REMOTE_PORT: remotePort,
Now: nowFunc,
}
}
} else {
// domain or ipv4 without port
host = redirAddr
ret.usedRandom = make(map[[32]byte]int)
return ret, nil
}
redirHost, err := net.ResolveIPAddr("ip", host)
// semi-colon separated value.
func ssvToJson(ssv string) (ret []byte) {
unescape := func(s string) string {
r := strings.Replace(s, `\\`, `\`, -1)
r = strings.Replace(r, `\=`, `=`, -1)
r = strings.Replace(r, `\;`, `;`, -1)
return r
}
lines := strings.Split(unescape(ssv), ";")
ret = []byte("{")
for _, ln := range lines {
if ln == "" {
break
}
sp := strings.SplitN(ln, "=", 2)
key := sp[0]
value := sp[1]
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)
}
ret = ret[:len(ret)-1] // remove the last comma
ret = append(ret, '}')
return ret
}
// ParseConfig parses the config (either a path to json or in-line ssv config) into a State variable
func (sta *State) ParseConfig(conf string) (err error) {
var content []byte
if strings.Contains(conf, ";") && strings.Contains(conf, "=") {
content = ssvToJson(conf)
} else {
content, err = ioutil.ReadFile(conf)
if err != nil {
return nil, "", fmt.Errorf("unable to resolve RedirAddr: %v. ", err)
return err
}
return redirHost, port, nil
}
func parseProxyBook(bookEntries map[string][]string) (map[string]net.Addr, error) {
proxyBook := map[string]net.Addr{}
for name, pair := range bookEntries {
name = strings.ToLower(name)
if len(pair) != 2 {
return nil, fmt.Errorf("invalid proxy endpoint and address pair for %v: %v", name, pair)
}
network := strings.ToLower(pair[0])
switch network {
case "tcp":
addr, err := net.ResolveTCPAddr("tcp", pair[1])
var preParse rawConfig
err = json.Unmarshal(content, &preParse)
if err != nil {
return nil, err
return errors.New("Failed to unmarshal: " + err.Error())
}
proxyBook[name] = addr
continue
case "udp":
addr, err := net.ResolveUDPAddr("udp", pair[1])
up, err := usermanager.MakeUserpanel(preParse.DatabasePath, preParse.BackupDirPath)
if err != nil {
return nil, err
}
proxyBook[name] = addr
continue
}
}
return proxyBook, nil
return err
}
sta.Userpanel = up
// ParseConfig reads the config file or semicolon-separated options and parse them into a RawConfig
func ParseConfig(conf string) (raw RawConfig, err error) {
content, errPath := ioutil.ReadFile(conf)
if errPath != nil {
errJson := json.Unmarshal(content, &raw)
if errJson != nil {
err = fmt.Errorf("failed to read/unmarshal configuration, path is invalid or %v", errJson)
return
}
} else {
errJson := json.Unmarshal(content, &raw)
if errJson != nil {
err = fmt.Errorf("failed to read configuration file: %v", errJson)
return
}
}
if raw.ProxyBook == nil {
raw.ProxyBook = make(map[string][]string)
}
return
}
sta.WebServerAddr = preParse.WebServerAddr
// InitState process the RawConfig and initialises a server State accordingly
func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, err error) {
sta = &State{
BypassUID: make(map[[16]byte]struct{}),
ProxyBook: map[string]net.Addr{},
UsedRandom: map[[32]byte]int64{},
RedirDialer: &net.Dialer{},
WorldState: worldState,
}
if preParse.CncMode {
err = errors.New("command & control mode not implemented")
return
} else {
var manager usermanager.UserManager
if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" {
manager = &usermanager.Voidmanager{}
} else {
manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState)
pvBytes, err := base64.StdEncoding.DecodeString(preParse.PrivateKey)
if err != nil {
return sta, err
}
}
sta.Panel = MakeUserPanel(manager)
}
if preParse.KeepAlive <= 0 {
sta.ProxyDialer = &net.Dialer{KeepAlive: -1}
} else {
sta.ProxyDialer = &net.Dialer{KeepAlive: time.Duration(preParse.KeepAlive) * time.Second}
}
sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr)
if err != nil {
err = fmt.Errorf("unable to parse RedirAddr: %v", err)
return
}
sta.ProxyBook, err = parseProxyBook(preParse.ProxyBook)
if err != nil {
err = fmt.Errorf("unable to parse ProxyBook: %v", err)
return
}
if len(preParse.PrivateKey) == 0 {
err = fmt.Errorf("must have a valid private key. Run `ck-server -key` to generate one")
return
return errors.New("Failed to decode private key: " + err.Error())
}
var pv [32]byte
copy(pv[:], preParse.PrivateKey)
sta.StaticPv = &pv
copy(pv[:], pvBytes)
sta.staticPv = &pv
sta.AdminUID = preParse.AdminUID
var arrUID [16]byte
for _, UID := range preParse.BypassUID {
copy(arrUID[:], UID)
sta.BypassUID[arrUID] = struct{}{}
adminUID, err := base64.StdEncoding.DecodeString(preParse.AdminUID)
if err != nil {
return errors.New("Failed to decode AdminUID: " + err.Error())
}
if len(sta.AdminUID) != 0 {
copy(arrUID[:], sta.AdminUID)
sta.BypassUID[arrUID] = struct{}{}
sta.AdminUID = adminUID
return nil
}
go sta.UsedRandomCleaner()
return sta, nil
func (sta *State) getUsedRandom(random [32]byte) int {
sta.usedRandomM.Lock()
defer sta.usedRandomM.Unlock()
return sta.usedRandom[random]
}
// IsBypass checks if a UID is a bypass user
func (sta *State) IsBypass(UID []byte) bool {
var arrUID [16]byte
copy(arrUID[:], UID)
_, exist := sta.BypassUID[arrUID]
return exist
// PutUsedRandom adds a random field into map usedRandom
func (sta *State) putUsedRandom(random [32]byte) {
sta.usedRandomM.Lock()
sta.usedRandom[random] = int(sta.Now().Unix())
sta.usedRandomM.Unlock()
}
const timestampTolerance = 180 * time.Second
const replayCacheAgeLimit = 12 * time.Hour
// UsedRandomCleaner clears the cache of used random fields every replayCacheAgeLimit
// UsedRandomCleaner clears the cache of used random fields every 12 hours
func (sta *State) UsedRandomCleaner() {
for {
time.Sleep(replayCacheAgeLimit)
time.Sleep(12 * time.Hour)
now := int(sta.Now().Unix())
sta.usedRandomM.Lock()
for key, t := range sta.UsedRandom {
if time.Unix(t, 0).Before(sta.WorldState.Now().Add(timestampTolerance)) {
delete(sta.UsedRandom, key)
for key, t := range sta.usedRandom {
if now-t > 12*3600 {
delete(sta.usedRandom, key)
}
}
sta.usedRandomM.Unlock()
}
}
func (sta *State) registerRandom(r [32]byte) bool {
sta.usedRandomM.Lock()
_, used := sta.UsedRandom[r]
sta.UsedRandom[r] = sta.WorldState.Now().Unix()
sta.usedRandomM.Unlock()
return used
}

View File

@ -1,126 +0,0 @@
package server
import (
"net"
"testing"
)
func TestParseRedirAddr(t *testing.T) {
t.Run("ipv4 without port", func(t *testing.T) {
ipv4noPort := "1.2.3.4"
host, port, err := parseRedirAddr(ipv4noPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv4noPort, err)
return
}
if host.String() != "1.2.3.4" {
t.Errorf("expected %v got %v", "1.2.3.4", host.String())
}
if port != "" {
t.Errorf("port not empty when there is no port")
}
})
t.Run("ipv4 with port", func(t *testing.T) {
ipv4wPort := "1.2.3.4:1234"
host, port, err := parseRedirAddr(ipv4wPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv4wPort, err)
return
}
if host.String() != "1.2.3.4" {
t.Errorf("expected %v got %v", "1.2.3.4", host.String())
}
if port != "1234" {
t.Errorf("wrong port: expected %v, got %v", "1234", port)
}
})
t.Run("domain without port", func(t *testing.T) {
domainNoPort := "example.com"
host, port, err := parseRedirAddr(domainNoPort)
if err != nil {
t.Errorf("parsing %v error: %v", domainNoPort, err)
return
}
expIPs, err := net.LookupIP("example.com")
if err != nil {
t.Errorf("tester error: cannot resolve example.com: %v", err)
return
}
contain := false
for _, expIP := range expIPs {
if expIP.String() == host.String() {
contain = true
}
}
if !contain {
t.Errorf("expected one of %v got %v", expIPs, host.String())
}
if port != "" {
t.Errorf("port not empty when there is no port")
}
})
t.Run("domain with port", func(t *testing.T) {
domainWPort := "example.com:80"
host, port, err := parseRedirAddr(domainWPort)
if err != nil {
t.Errorf("parsing %v error: %v", domainWPort, err)
return
}
expIPs, err := net.LookupIP("example.com")
if err != nil {
t.Errorf("tester error: cannot resolve example.com: %v", err)
return
}
contain := false
for _, expIP := range expIPs {
if expIP.String() == host.String() {
contain = true
}
}
if !contain {
t.Errorf("expected one of %v got %v", expIPs, host.String())
}
if port != "80" {
t.Errorf("wrong port: expected %v, got %v", "80", port)
}
})
t.Run("ipv6 without port", func(t *testing.T) {
ipv6noPort := "a:b:c:d::"
host, port, err := parseRedirAddr(ipv6noPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv6noPort, err)
return
}
if host.String() != "a:b:c:d::" {
t.Errorf("expected %v got %v", "a:b:c:d::", host.String())
}
if port != "" {
t.Errorf("port not empty when there is no port")
}
})
t.Run("ipv6 with port", func(t *testing.T) {
ipv6wPort := "[a:b:c:d::]:80"
host, port, err := parseRedirAddr(ipv6wPort)
if err != nil {
t.Errorf("parsing %v error: %v", ipv6wPort, err)
return
}
if host.String() != "a:b:c:d::" {
t.Errorf("expected %v got %v", "a:b:c:d::", host.String())
}
if port != "80" {
t.Errorf("wrong port: expected %v, got %v", "80", port)
}
})
}

View File

@ -1,16 +0,0 @@
package server
import (
"crypto"
"errors"
"io"
"net"
)
type Responder = func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error)
type Transport interface {
processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (authFragments, Responder, error)
}
var ErrInvalidPubKey = errors.New("public key has invalid format")
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")

View File

@ -1,152 +0,0 @@
swagger: '2.0'
info:
description: |
This is the API of Cloak server
version: 0.0.2
title: Cloak Server
contact:
email: cbeuw.andy@gmail.com
license:
name: GPLv3
url: https://www.gnu.org/licenses/gpl-3.0.en.html
# host: petstore.swagger.io
# basePath: /v2
tags:
- name: users
description: Operations related to user controls by admin
# schemes:
# - http
paths:
/admin/users:
get:
tags:
- users
summary: Show all users
description: Returns an array of all UserInfo
operationId: listAllUsers
produces:
- application/json
responses:
200:
description: successful operation
schema:
type: array
items:
$ref: '#/definitions/UserInfo'
500:
description: internal error
/admin/users/{UID}:
get:
tags:
- users
summary: Show userinfo by UID
description: Returns a UserInfo object
operationId: getUserInfo
produces:
- application/json
parameters:
- name: UID
in: path
description: UID of the user
required: true
type: string
format: byte
responses:
200:
description: successful operation
schema:
$ref: '#/definitions/UserInfo'
400:
description: bad request
404:
description: User not found
500:
description: internal error
post:
tags:
- users
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
operationId: writeUserInfo
consumes:
- application/json
produces:
- application/json
parameters:
- name: UID
in: path
description: UID of the user
required: true
type: string
format: byte
- name: UserInfo
in: body
description: New userinfo
required: true
schema:
type: array
items:
$ref: '#/definitions/UserInfo'
responses:
201:
description: successful operation
400:
description: bad request
500:
description: internal error
delete:
tags:
- users
summary: Deletes a user
operationId: deleteUser
produces:
- application/json
parameters:
- name: UID
in: path
description: UID of the user to be deleted
required: true
type: string
format: byte
responses:
200:
description: successful operation
400:
description: bad request
404:
description: User not found
500:
description: internal error
definitions:
UserInfo:
type: object
properties:
UID:
type: string
format: byte
SessionsCap:
type: integer
format: int32
UpRate:
type: integer
format: int64
DownRate:
type: integer
format: int64
UpCredit:
type: integer
format: int64
DownCredit:
type: integer
format: int64
ExpiryTime:
type: integer
format: int64
externalDocs:
description: Find out more about Swagger
url: http://swagger.io
# Added by API Auto Mocking Plugin
host: 127.0.0.1:8080
basePath: /
schemes:
- http

View File

@ -1,129 +0,0 @@
package usermanager
import (
"bytes"
"encoding/base64"
"encoding/json"
"net/http"
gmux "github.com/gorilla/mux"
)
type APIRouter struct {
*gmux.Router
manager UserManager
}
func APIRouterOf(manager UserManager) *APIRouter {
ret := &APIRouter{
manager: manager,
}
ret.registerMux()
return ret
}
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
next.ServeHTTP(w, r)
})
}
func (ar *APIRouter) registerMux() {
ar.Router = gmux.NewRouter()
ar.HandleFunc("/admin/users", ar.listAllUsersHlr).Methods("GET")
ar.HandleFunc("/admin/users/{UID}", ar.getUserInfoHlr).Methods("GET")
ar.HandleFunc("/admin/users/{UID}", ar.writeUserInfoHlr).Methods("POST")
ar.HandleFunc("/admin/users/{UID}", ar.deleteUserHlr).Methods("DELETE")
ar.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS")
})
ar.Use(corsMiddleware)
}
func (ar *APIRouter) listAllUsersHlr(w http.ResponseWriter, r *http.Request) {
infos, err := ar.manager.ListAllUsers()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
resp, err := json.Marshal(infos)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(resp)
}
func (ar *APIRouter) getUserInfoHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
uinfo, err := ar.manager.GetUserInfo(UID)
if err == ErrUserNotFound {
http.Error(w, ErrUserNotFound.Error(), http.StatusNotFound)
return
}
resp, err := json.Marshal(uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(resp)
}
func (ar *APIRouter) writeUserInfoHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
return
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var uinfo UserInfo
err = json.NewDecoder(r.Body).Decode(&uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if !bytes.Equal(UID, uinfo.UID) {
http.Error(w, "UID mismatch", http.StatusBadRequest)
}
err = ar.manager.WriteUserInfo(uinfo)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
w.WriteHeader(http.StatusCreated)
}
func (ar *APIRouter) deleteUserHlr(w http.ResponseWriter, r *http.Request) {
b64UID := gmux.Vars(r)["UID"]
if b64UID == "" {
http.Error(w, "UID cannot be empty", http.StatusBadRequest)
return
}
UID, err := base64.URLEncoding.DecodeString(b64UID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = ar.manager.DeleteUser(UID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
w.WriteHeader(http.StatusOK)
}

View File

@ -1,214 +0,0 @@
package usermanager
import (
"bytes"
"encoding/base64"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
var mockUIDb64 = base64.URLEncoding.EncodeToString(mockUID)
func makeRouter(t *testing.T) (router *APIRouter, cleaner func()) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
cleaner = func() { os.Remove(tmpDB.Name()) }
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
router = APIRouterOf(mgr)
return router, cleaner
}
func TestWriteUserInfoHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
marshalled, err := json.Marshal(mockUserInfo)
if err != nil {
t.Fatal(err)
}
t.Run("ok", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
})
t.Run("partial update", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
assert.NoError(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
partialUserInfo := UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
}
partialMarshalled, _ := json.Marshal(partialUserInfo)
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
var got UserInfo
err = json.Unmarshal(rr.Body.Bytes(), &got)
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = partialUserInfo.SessionsCap
assert.EqualValues(t, expected, got)
})
t.Run("empty parameter", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusMethodNotAllowed, rr.Code, "response body: %v", rr.Body)
})
t.Run("UID mismatch", func(t *testing.T) {
badMock := mockUserInfo
badMock.UID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0}
badMarshal, err := json.Marshal(badMock)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(badMarshal))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
})
t.Run("garbage data", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer([]byte(`{"{{'{;;}}}1`)))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
})
t.Run("not base64", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+"defonotbase64", bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusBadRequest, rr.Code, "response body: %v", rr.Body)
})
}
func addUser(t *testing.T, router *APIRouter, user UserInfo) {
marshalled, err := json.Marshal(user)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("POST", "/admin/users/"+base64.URLEncoding.EncodeToString(user.UID), bytes.NewBuffer(marshalled))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
}
func TestGetUserInfoHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
t.Run("empty parameter", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/", nil)
})
t.Run("non-existent", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+base64.URLEncoding.EncodeToString([]byte("adsf")), nil)
})
t.Run("not base64", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+"defonotbase64", nil)
})
t.Run("ok", func(t *testing.T) {
addUser(t, router, mockUserInfo)
var got UserInfo
err := json.Unmarshal([]byte(assert.HTTPBody(router.ServeHTTP, "GET", "/admin/users/"+mockUIDb64, nil)), &got)
if err != nil {
t.Fatal(err)
}
assert.EqualValues(t, mockUserInfo, got)
})
}
func TestDeleteUserHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
t.Run("non-existent", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "DELETE", "/admin/users/"+base64.URLEncoding.EncodeToString([]byte("adsf")), nil)
})
t.Run("not base64", func(t *testing.T) {
assert.HTTPError(t, router.ServeHTTP, "DELETE", "/admin/users/"+"defonotbase64", nil)
})
t.Run("ok", func(t *testing.T) {
addUser(t, router, mockUserInfo)
assert.HTTPSuccess(t, router.ServeHTTP, "DELETE", "/admin/users/"+mockUIDb64, nil)
assert.HTTPError(t, router.ServeHTTP, "GET", "/admin/users/"+mockUIDb64, nil)
})
}
func TestListAllUsersHlr(t *testing.T) {
router, cleaner := makeRouter(t)
defer cleaner()
user1 := mockUserInfo
addUser(t, router, user1)
user2 := mockUserInfo
user2.UID = []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
addUser(t, router, user2)
expected := []UserInfo{user1, user2}
var got []UserInfo
err := json.Unmarshal([]byte(assert.HTTPBody(router.ServeHTTP, "GET", "/admin/users", nil)), &got)
if err != nil {
t.Fatal(err)
}
assert.True(t, assert.Subset(t, got, expected), assert.Subset(t, expected, got))
}

View File

@ -0,0 +1,212 @@
package usermanager
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"errors"
"log"
)
// FIXME: sanity checks. The server may panic due to user input
// TODO: manual backup
/*
0 reserved
1 listActiveUsers none []uids
2 listAllUsers none []userinfo
3 getUserInfo uid userinfo
4 addNewUser userinfo ok
5 delUser uid ok
6 syncMemFromDB uid ok
7 setSessionsCap uid cap ok
8 setUpRate uid rate ok
9 setDownRate uid rate ok
10 setUpCredit uid credit ok
11 setDownCredit uid credit ok
12 setExpiryTime uid time ok
13 addUpCredit uid delta ok
14 addDownCredit uid delta ok
*/
type controller struct {
*Userpanel
adminUID []byte
}
func (up *Userpanel) MakeController(adminUID []byte) *controller {
return &controller{up, adminUID}
}
var errInvalidArgument = errors.New("Invalid argument format")
func (c *controller) HandleRequest(req []byte) (resp []byte, err error) {
check := func(err error) []byte {
if err != nil {
return c.respond([]byte(err.Error()))
} else {
return c.respond([]byte("ok"))
}
}
plain, err := c.checkAndDecrypt(req)
if err == ErrInvalidMac {
log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\n raw request:\n%x\ndecrypted msg:\n%x", req, plain)
return nil, err
} else if err != nil {
log.Println(err)
return c.respond([]byte(err.Error())), nil
}
typ := plain[0]
var arg []byte
if len(plain) > 1 {
arg = plain[1:]
}
switch typ {
case 1:
UIDs := c.listActiveUsers()
resp, _ = json.Marshal(UIDs)
resp = c.respond(resp)
case 2:
uinfos := c.listAllUsers()
resp, _ = json.Marshal(uinfos)
resp = c.respond(resp)
case 3:
uinfo, err := c.getUserInfo(arg)
if err != nil {
resp = c.respond([]byte(err.Error()))
break
}
resp, _ = json.Marshal(uinfo)
resp = c.respond(resp)
case 4:
var uinfo UserInfo
err = json.Unmarshal(arg, &uinfo)
if err != nil {
resp = c.respond([]byte(err.Error()))
break
}
err = c.addNewUser(uinfo)
resp = check(err)
case 5:
err = c.delUser(arg)
resp = check(err)
case 6:
err = c.syncMemFromDB(arg)
resp = check(err)
case 7:
if len(arg) < 36 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setSessionsCap(arg[0:32], Uint32(arg[32:36]))
resp = check(err)
case 8:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setUpRate(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 9:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setDownRate(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 10:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setUpCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 11:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setDownCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 12:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setExpiryTime(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 13:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.addUpCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 14:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.addDownCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
default:
return c.respond([]byte("Unsupported action")), nil
}
return
}
var ErrInvalidMac = errors.New("Mac mismatch")
var errMsgTooShort = errors.New("Message length is less than 54")
// protocol: [TLS record layer 5 bytes][IV 16 bytes][data][hmac 32 bytes]
func (c *controller) respond(resp []byte) []byte {
respLen := len(resp)
buf := make([]byte, 5+16+respLen+32)
buf[0] = 0x17
buf[1] = 0x03
buf[2] = 0x03
PutUint16(buf[3:5], uint16(16+respLen+32))
rand.Read(buf[5:21]) //iv
copy(buf[21:], resp)
block, _ := aes.NewCipher(c.adminUID[0:16])
stream := cipher.NewCTR(block, buf[5:21])
stream.XORKeyStream(buf[21:21+respLen], buf[21:21+respLen])
mac := hmac.New(sha256.New, c.adminUID[16:32])
mac.Write(buf[5 : 21+respLen])
copy(buf[21+respLen:], mac.Sum(nil))
return buf
}
func (c *controller) checkAndDecrypt(data []byte) ([]byte, error) {
if len(data) < 54 {
return nil, errMsgTooShort
}
macIndex := len(data) - 32
mac := hmac.New(sha256.New, c.adminUID[16:32])
mac.Write(data[5:macIndex])
expected := mac.Sum(nil)
if !hmac.Equal(data[macIndex:], expected) {
return nil, ErrInvalidMac
}
iv := data[5:21]
ret := data[21:macIndex]
block, _ := aes.NewCipher(c.adminUID[0:16])
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(ret, ret)
return ret, nil
}

View File

@ -1,269 +0,0 @@
package usermanager
import (
"encoding/binary"
"github.com/cbeuw/Cloak/internal/common"
log "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
var u32 = binary.BigEndian.Uint32
var u64 = binary.BigEndian.Uint64
func i64ToB(value int64) []byte {
oct := make([]byte, 8)
binary.BigEndian.PutUint64(oct, uint64(value))
return oct
}
func i32ToB(value int32) []byte {
nib := make([]byte, 4)
binary.BigEndian.PutUint32(nib, uint32(value))
return nib
}
// localManager is responsible for managing the local user database
type localManager struct {
db *bolt.DB
world common.WorldState
}
func MakeLocalManager(dbPath string, worldState common.WorldState) (*localManager, error) {
db, err := bolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
ret := &localManager{
db: db,
world: worldState,
}
return ret, nil
}
// Authenticate user returns err==nil along with the users' up and down bandwidths if the UID is allowed to connect
// More specifically it checks that the user exists, that it has positive credit and that it hasn't expired
func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) {
var upRate, downRate, upCredit, downCredit, expiryTime int64
err := manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(UID)
if bucket == nil {
return ErrUserNotFound
}
upRate = int64(u64(bucket.Get([]byte("UpRate"))))
downRate = int64(u64(bucket.Get([]byte("DownRate"))))
upCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
downCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
expiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
return 0, 0, err
}
if upCredit <= 0 {
return 0, 0, ErrNoUpCredit
}
if downCredit <= 0 {
return 0, 0, ErrNoDownCredit
}
if expiryTime < manager.world.Now().Unix() {
return 0, 0, ErrUserExpired
}
return upRate, downRate, nil
}
// AuthoriseNewSession returns err==nil when the user is allowed to make a new session
// More specifically it checks that the user exists, has credit, hasn't expired and hasn't reached sessionsCap
func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo AuthorisationInfo) error {
var arrUID [16]byte
copy(arrUID[:], UID)
var sessionsCap int
var upCredit, downCredit, expiryTime int64
err := manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(arrUID[:])
if bucket == nil {
return ErrUserNotFound
}
sessionsCap = int(u32(bucket.Get([]byte("SessionsCap"))))
upCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
downCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
expiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
return err
}
if upCredit <= 0 {
return ErrNoUpCredit
}
if downCredit <= 0 {
return ErrNoDownCredit
}
if expiryTime < manager.world.Now().Unix() {
return ErrUserExpired
}
if ainfo.NumExistingSessions >= sessionsCap {
return ErrSessionsCapReached
}
return nil
}
// UploadStatus gets StatusUpdates representing the recent status of each user, and update them in the database
// it returns a slice of StatusResponse, which represents actions need to be taken for specific users.
// If no action is needed, there won't be a StatusResponse entry for that user
func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusResponse, error) {
var responses []StatusResponse
if len(uploads) == 0 {
return responses, nil
}
err := manager.db.Update(func(tx *bolt.Tx) error {
for _, status := range uploads {
var resp StatusResponse
bucket := tx.Bucket(status.UID)
if bucket == nil {
resp = StatusResponse{
status.UID,
TERMINATE,
"User no longer exists",
}
responses = append(responses, resp)
continue
}
oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
newUp := oldUp - status.UpUsage
if newUp <= 0 {
resp = StatusResponse{
status.UID,
TERMINATE,
"No upload credit left",
}
responses = append(responses, resp)
}
err := bucket.Put([]byte("UpCredit"), i64ToB(newUp))
if err != nil {
log.Error(err)
}
oldDown := int64(u64(bucket.Get([]byte("DownCredit"))))
newDown := oldDown - status.DownUsage
if newDown <= 0 {
resp = StatusResponse{
status.UID,
TERMINATE,
"No download credit left",
}
responses = append(responses, resp)
}
err = bucket.Put([]byte("DownCredit"), i64ToB(newDown))
if err != nil {
log.Error(err)
}
expiry := int64(u64(bucket.Get([]byte("ExpiryTime"))))
if manager.world.Now().Unix() > expiry {
resp = StatusResponse{
status.UID,
TERMINATE,
"User has expired",
}
responses = append(responses, resp)
}
}
return nil
})
return responses, err
}
func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
err = manager.db.View(func(tx *bolt.Tx) error {
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
var uinfo UserInfo
uinfo.UID = UID
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
infos = append(infos, uinfo)
return nil
})
return err
})
if infos == nil {
infos = []UserInfo{}
}
return
}
func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error) {
err = manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(UID)
if bucket == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
return nil
})
return
}
func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(u.UID)
if err != nil {
return err
}
if u.SessionsCap != nil {
if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
return err
}
}
if u.UpRate != nil {
if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
return err
}
}
if u.DownRate != nil {
if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
return err
}
}
if u.UpCredit != nil {
if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
return err
}
}
if u.DownCredit != nil {
if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
return err
}
}
if u.ExpiryTime != nil {
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
return err
}
}
return nil
})
return
}
func (manager *localManager) DeleteUser(UID []byte) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error {
return tx.DeleteBucket(UID)
})
return
}
func (manager *localManager) Close() error {
return manager.db.Close()
}

View File

@ -1,365 +0,0 @@
package usermanager
import (
"encoding/binary"
"io/ioutil"
"math/rand"
"os"
"reflect"
"sort"
"sync"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
)
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var mockUserInfo = UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
UpRate: JustInt64(100),
DownRate: JustInt64(1000),
UpCredit: JustInt64(10000),
DownCredit: JustInt64(100000),
ExpiryTime: JustInt64(1000000),
}
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
cleaner = func() { os.Remove(tmpDB.Name()) }
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
return mgr, cleaner
}
func TestLocalManager_WriteUserInfo(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
err := mgr.WriteUserInfo(mockUserInfo)
if err != nil {
t.Error(err)
}
got, err := mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, mockUserInfo, got)
/* Partial update */
err = mgr.WriteUserInfo(UserInfo{
UID: mockUID,
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
})
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
got, err = mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, expected, got)
}
func TestLocalManager_GetUserInfo(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
t.Run("simple fetch", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo)
gotInfo, err := mgr.GetUserInfo(mockUID)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gotInfo, mockUserInfo) {
t.Errorf("got wrong user info: %v", gotInfo)
}
})
t.Run("update a field", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo)
updatedUserInfo := mockUserInfo
updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
err := mgr.WriteUserInfo(updatedUserInfo)
if err != nil {
t.Error(err)
}
gotInfo, err := mgr.GetUserInfo(mockUID)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gotInfo, updatedUserInfo) {
t.Errorf("got wrong user info: %v", updatedUserInfo)
}
})
t.Run("non existent user", func(t *testing.T) {
_, err := mgr.GetUserInfo(make([]byte, 16))
if err != ErrUserNotFound {
t.Errorf("expecting error %v, got %v", ErrUserNotFound, err)
}
})
}
func TestLocalManager_DeleteUser(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
_ = mgr.WriteUserInfo(mockUserInfo)
err := mgr.DeleteUser(mockUID)
if err != nil {
t.Error(err)
}
_, err = mgr.GetUserInfo(mockUID)
if err != ErrUserNotFound {
t.Error("user not deleted")
}
}
var validUserInfo = mockUserInfo
func TestLocalManager_AuthenticateUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
mgr, err := MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
t.Run("normal auth", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
upRate, downRate, err := mgr.AuthenticateUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
t.Error("wrong up or down rate")
}
})
t.Run("non existent user", func(t *testing.T) {
_, _, err := mgr.AuthenticateUser(make([]byte, 16))
if err != ErrUserNotFound {
t.Error("user found")
}
})
t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo)
_, _, err := mgr.AuthenticateUser(expiredUserInfo.UID)
if err != ErrUserExpired {
t.Error("user not expired")
}
})
t.Run("no credit", func(t *testing.T) {
creditlessUserInfo := validUserInfo
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
_ = mgr.WriteUserInfo(creditlessUserInfo)
_, _, err := mgr.AuthenticateUser(creditlessUserInfo.UID)
if err != ErrNoUpCredit && err != ErrNoDownCredit {
t.Error("user not creditless")
}
})
}
func TestLocalManager_AuthoriseNewSession(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
t.Run("normal auth", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
if err != nil {
t.Error(err)
}
})
t.Run("non existent user", func(t *testing.T) {
err := mgr.AuthoriseNewSession(make([]byte, 16), AuthorisationInfo{NumExistingSessions: 0})
if err != ErrUserNotFound {
t.Error("user found")
}
})
t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo)
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
if err != ErrUserExpired {
t.Error("user not expired")
}
})
t.Run("too many sessions", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
if err != ErrSessionsCapReached {
t.Error("session cap not reached")
}
})
}
func TestLocalManager_UploadStatus(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
t.Run("simple update", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
update := StatusUpdate{
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: 10,
DownUsage: 100,
Timestamp: mockWorldState.Now().Unix(),
}
_, err := mgr.UploadStatus([]StatusUpdate{update})
if err != nil {
t.Error(err)
}
updatedUserInfo, err := mgr.GetUserInfo(validUserInfo.UID)
if err != nil {
t.Error(err)
}
if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
t.Error("up usage incorrect")
}
if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
t.Error("down usage incorrect")
}
})
badUpdates := []struct {
name string
user UserInfo
update StatusUpdate
}{
{"out of up credit",
validUserInfo,
StatusUpdate{
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: *validUserInfo.UpCredit + 100,
DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(),
},
},
{"out of down credit",
validUserInfo,
StatusUpdate{
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: 0,
DownUsage: *validUserInfo.DownCredit + 100,
Timestamp: mockWorldState.Now().Unix(),
},
},
{"expired",
UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
UpRate: JustInt64(0),
DownRate: JustInt64(0),
UpCredit: JustInt64(0),
DownCredit: JustInt64(0),
ExpiryTime: JustInt64(-1),
},
StatusUpdate{
UID: mockUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: 0,
DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(),
},
},
}
for _, badUpdate := range badUpdates {
t.Run(badUpdate.name, func(t *testing.T) {
_ = mgr.WriteUserInfo(badUpdate.user)
resps, err := mgr.UploadStatus([]StatusUpdate{badUpdate.update})
if err != nil {
t.Error(err)
}
if len(resps) == 0 {
t.Fatal("expecting responses")
}
resp := resps[0]
if resp.Action != TERMINATE {
t.Errorf("didn't terminate when %v", badUpdate.name)
}
})
}
}
func TestLocalManager_ListAllUsers(t *testing.T) {
mgr, cleaner := makeManager(t)
defer cleaner()
var wg sync.WaitGroup
var users []UserInfo
for i := 0; i < 100; i++ {
randUID := make([]byte, 16)
rand.Read(randUID)
newUser := UserInfo{
UID: randUID,
SessionsCap: JustInt32(rand.Int31()),
UpRate: JustInt64(rand.Int63()),
DownRate: JustInt64(rand.Int63()),
UpCredit: JustInt64(rand.Int63()),
DownCredit: JustInt64(rand.Int63()),
ExpiryTime: JustInt64(rand.Int63()),
}
users = append(users, newUser)
wg.Add(1)
go func() {
err := mgr.WriteUserInfo(newUser)
if err != nil {
t.Fatal(err)
}
wg.Done()
}()
}
wg.Wait()
listedUsers, err := mgr.ListAllUsers()
if err != nil {
t.Error(err)
}
sort.Slice(users, func(i, j int) bool {
return binary.BigEndian.Uint64(users[i].UID[0:8]) < binary.BigEndian.Uint64(users[j].UID[0:8])
})
sort.Slice(listedUsers, func(i, j int) bool {
return binary.BigEndian.Uint64(listedUsers[i].UID[0:8]) < binary.BigEndian.Uint64(listedUsers[j].UID[0:8])
})
if !reflect.DeepEqual(users, listedUsers) {
t.Error("listed users deviates from uploaded ones")
}
}

View File

@ -0,0 +1,106 @@
package usermanager
import (
"errors"
"log"
"net"
"sync"
"sync/atomic"
"time"
mux "github.com/cbeuw/Cloak/internal/multiplex"
)
// for the ease of using json package
type UserInfo struct {
UID []byte
// ALL of the following fields have to be accessed atomically
SessionsCap uint32
UpRate int64
DownRate int64
UpCredit int64
DownCredit int64
ExpiryTime int64
}
type User struct {
up *Userpanel
arrUID [32]byte
*UserInfo
valve *mux.Valve
sessionsM sync.RWMutex
sessions map[uint32]*mux.Session
}
func MakeUser(up *Userpanel, uinfo *UserInfo) *User {
// this instance of valve is shared across ALL sessions of a user
valve := mux.MakeValve(uinfo.UpRate, uinfo.DownRate, &uinfo.UpCredit, &uinfo.DownCredit)
u := &User{
up: up,
UserInfo: uinfo,
valve: valve,
sessions: make(map[uint32]*mux.Session),
}
copy(u.arrUID[:], uinfo.UID)
return u
}
func (u *User) addUpCredit(delta int64) { u.valve.AddRxCredit(delta) }
func (u *User) addDownCredit(delta int64) { u.valve.AddTxCredit(delta) }
func (u *User) setSessionsCap(cap uint32) { atomic.StoreUint32(&u.SessionsCap, cap) }
func (u *User) setUpRate(rate int64) { u.valve.SetRxRate(rate) }
func (u *User) setDownRate(rate int64) { u.valve.SetTxRate(rate) }
func (u *User) setUpCredit(n int64) { u.valve.SetRxCredit(n) }
func (u *User) setDownCredit(n int64) { u.valve.SetTxCredit(n) }
func (u *User) setExpiryTime(time int64) { atomic.StoreInt64(&u.ExpiryTime, time) }
func (u *User) updateInfo(uinfo UserInfo) {
u.setSessionsCap(uinfo.SessionsCap)
u.setUpCredit(uinfo.UpCredit)
u.setDownCredit(uinfo.DownCredit)
u.setUpRate(uinfo.UpRate)
u.setDownRate(uinfo.DownRate)
u.setExpiryTime(uinfo.ExpiryTime)
}
func (u *User) PutSession(sessionID uint32, sesh *mux.Session) {
u.sessionsM.Lock()
u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
}
func (u *User) DelSession(sessionID uint32) {
u.sessionsM.Lock()
delete(u.sessions, sessionID)
if len(u.sessions) == 0 {
u.sessionsM.Unlock()
u.up.delActiveUser(u.UID)
return
}
u.sessionsM.Unlock()
}
func (u *User) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
if time.Now().Unix() > u.ExpiryTime {
return nil, false, errors.New("Expiry time passed")
}
u.sessionsM.Lock()
if sesh = u.sessions[sessionID]; sesh != nil {
u.sessionsM.Unlock()
return sesh, true, nil
} else {
if len(u.sessions) >= int(u.SessionsCap) {
u.sessionsM.Unlock()
return nil, false, errors.New("SessionsCap reached")
}
log.Printf("Creating session %v\n", sessionID)
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
return sesh, false, nil
}
}

View File

@ -1,64 +0,0 @@
package usermanager
import (
"errors"
)
type StatusUpdate struct {
UID []byte
Active bool
NumSession int
UpUsage int64
DownUsage int64
Timestamp int64
}
type MaybeInt32 *int32
type MaybeInt64 *int64
type UserInfo struct {
UID []byte
SessionsCap MaybeInt32
UpRate MaybeInt64
DownRate MaybeInt64
UpCredit MaybeInt64
DownCredit MaybeInt64
ExpiryTime MaybeInt64
}
func JustInt32(v int32) MaybeInt32 { return &v }
func JustInt64(v int64) MaybeInt64 { return &v }
type StatusResponse struct {
UID []byte
Action int
Message string
}
type AuthorisationInfo struct {
NumExistingSessions int
}
const (
TERMINATE = iota + 1
)
var ErrUserNotFound = errors.New("UID does not correspond to a user")
var ErrSessionsCapReached = errors.New("Sessions cap has reached")
var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified")
var ErrNoUpCredit = errors.New("No upload credit left")
var ErrNoDownCredit = errors.New("No download credit left")
var ErrUserExpired = errors.New("User has expired")
type UserManager interface {
AuthenticateUser([]byte) (int64, int64, error)
AuthoriseNewSession([]byte, AuthorisationInfo) error
UploadStatus([]StatusUpdate) ([]StatusResponse, error)
ListAllUsers() ([]UserInfo, error)
GetUserInfo(UID []byte) (UserInfo, error)
WriteUserInfo(UserInfo) error
DeleteUser(UID []byte) error
}

View File

@ -0,0 +1,471 @@
package usermanager
import (
"encoding/base64"
"encoding/binary"
"errors"
"os"
"path"
"strconv"
"sync"
"time"
"github.com/boltdb/bolt"
)
var Uint32 = binary.BigEndian.Uint32
var Uint64 = binary.BigEndian.Uint64
var PutUint16 = binary.BigEndian.PutUint16
var PutUint32 = binary.BigEndian.PutUint32
var PutUint64 = binary.BigEndian.PutUint64
type Userpanel struct {
db *bolt.DB
bakRoot string
activeUsersM sync.RWMutex
activeUsers map[[32]byte]*User
}
func MakeUserpanel(dbPath, bakRoot string) (*Userpanel, error) {
db, err := bolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
if bakRoot == "" {
os.Mkdir("db-backup", 0777)
bakRoot = "db-backup"
}
bakRoot = path.Clean(bakRoot)
up := &Userpanel{
db: db,
bakRoot: bakRoot,
activeUsers: make(map[[32]byte]*User),
}
go func() {
for {
time.Sleep(time.Second * 10)
up.updateCredits()
}
}()
return up, nil
}
// credits of all users are updated together so that there is only 1 goroutine managing it
func (up *Userpanel) updateCredits() {
up.activeUsersM.RLock()
for _, u := range up.activeUsers {
up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(u.arrUID[:])
if b == nil {
return ErrUserNotFound
}
if err := b.Put([]byte("UpCredit"), i64ToB(u.valve.GetRxCredit())); err != nil {
return err
}
if err := b.Put([]byte("DownCredit"), i64ToB(u.valve.GetTxCredit())); err != nil {
return err
}
return nil
})
}
up.activeUsersM.RUnlock()
}
func (up *Userpanel) backupDB(bakFileName string) error {
bakPath := up.bakRoot + "/" + bakFileName
_, err := os.Stat(bakPath)
if err == nil {
return errors.New("Attempting to overwrite a file during backup!")
}
var bak *os.File
if os.IsNotExist(err) {
bak, err = os.Create(bakPath)
if err != nil {
return err
}
}
err = up.db.View(func(tx *bolt.Tx) error {
_, err := tx.WriteTo(bak)
if err != nil {
return err
}
return nil
})
return err
}
var ErrUserNotFound = errors.New("User does not exist in db")
var ErrUserNotActive = errors.New("User is not active")
func (up *Userpanel) GetAndActivateAdminUser(AdminUID []byte) (*User, error) {
up.activeUsersM.Lock()
var arrUID [32]byte
copy(arrUID[:], AdminUID)
if user, ok := up.activeUsers[arrUID]; ok {
up.activeUsersM.Unlock()
return user, nil
}
uinfo := UserInfo{
UID: AdminUID,
SessionsCap: 1e9,
UpRate: 1e12,
DownRate: 1e12,
UpCredit: 1e15,
DownCredit: 1e15,
ExpiryTime: 1e15,
}
user := MakeUser(up, &uinfo)
up.activeUsers[arrUID] = user
up.activeUsersM.Unlock()
return user, nil
}
// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's infor
// from the db and mark it as an active user
func (up *Userpanel) GetAndActivateUser(UID []byte) (*User, error) {
up.activeUsersM.Lock()
var arrUID [32]byte
copy(arrUID[:], UID)
if user, ok := up.activeUsers[arrUID]; ok {
up.activeUsersM.Unlock()
return user, nil
}
var uinfo UserInfo
uinfo.UID = UID
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID[:])
if b == nil {
return ErrUserNotFound
}
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) // reee brackets
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
up.activeUsersM.Unlock()
return nil, err
}
u := MakeUser(up, &uinfo)
up.activeUsers[arrUID] = u
up.activeUsersM.Unlock()
return u, nil
}
func (up *Userpanel) updateDBEntryUint32(UID []byte, key string, value uint32) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
if err := b.Put([]byte(key), u32ToB(value)); err != nil {
return err
}
return nil
})
return err
}
func (up *Userpanel) updateDBEntryInt64(UID []byte, key string, value int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
if err := b.Put([]byte(key), i64ToB(value)); err != nil {
return err
}
return nil
})
return err
}
// This is used when all sessions of a user close
func (up *Userpanel) delActiveUser(UID []byte) {
var arrUID [32]byte
copy(arrUID[:], UID)
up.activeUsersM.Lock()
delete(up.activeUsers, arrUID)
up.activeUsersM.Unlock()
}
func (up *Userpanel) getActiveUser(UID []byte) *User {
var arrUID [32]byte
copy(arrUID[:], UID)
up.activeUsersM.RLock()
ret := up.activeUsers[arrUID]
up.activeUsersM.RUnlock()
return ret
}
// below are remote control utilised functions
func (up *Userpanel) listActiveUsers() [][]byte {
var ret [][]byte
up.activeUsersM.RLock()
for _, u := range up.activeUsers {
ret = append(ret, u.UID)
}
up.activeUsersM.RUnlock()
return ret
}
func (up *Userpanel) listAllUsers() []UserInfo {
var ret []UserInfo
up.db.View(func(tx *bolt.Tx) error {
tx.ForEach(func(UID []byte, b *bolt.Bucket) error {
// if we want to avoid writing every single key out,
// we would have to either make UserInfo a map,
// or use reflect.
// neither is convinient
var uinfo UserInfo
uinfo.UID = UID
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
ret = append(ret, uinfo)
return nil
})
return nil
})
return ret
}
func (up *Userpanel) getUserInfo(UID []byte) (UserInfo, error) {
var uinfo UserInfo
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
return nil
})
return uinfo, err
}
// In boltdb, the value argument for bucket.Put has to be valid for the duration
// of the transaction.
// This basically means that you cannot reuse a byte slice for two different keys
// in a transaction. So we need to allocate a fresh byte slice for each value
func u32ToB(value uint32) []byte {
quad := make([]byte, 4)
PutUint32(quad, value)
return quad
}
func i64ToB(value int64) []byte {
oct := make([]byte, 8)
PutUint64(oct, uint64(value))
return oct
}
func (up *Userpanel) addNewUser(uinfo UserInfo) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b, err := tx.CreateBucket(uinfo.UID[:])
if err != nil {
return err
}
if err = b.Put([]byte("SessionsCap"), u32ToB(uinfo.SessionsCap)); err != nil {
return err
}
if err = b.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
return err
}
if err = b.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
return err
}
if err = b.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
return err
}
if err = b.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
return err
}
if err = b.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
return err
}
return nil
})
return err
}
func (up *Userpanel) delUser(UID []byte) error {
err := up.backupDB(strconv.FormatInt(time.Now().Unix(), 10) + "_pre_del_" + base64.StdEncoding.EncodeToString(UID) + ".bak")
if err != nil {
return err
}
err = up.db.Update(func(tx *bolt.Tx) error {
return tx.DeleteBucket(UID)
})
return err
}
func (up *Userpanel) syncMemFromDB(UID []byte) error {
var uinfo UserInfo
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return ErrUserNotActive
}
u.updateInfo(uinfo)
return nil
}
// the following functions will return err==nil if user is not active
func (up *Userpanel) setSessionsCap(UID []byte, cap uint32) error {
err := up.updateDBEntryUint32(UID, "SessionsCap", cap)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setSessionsCap(cap)
return nil
}
func (up *Userpanel) setUpRate(UID []byte, rate int64) error {
err := up.updateDBEntryInt64(UID, "UpRate", rate)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setUpRate(rate)
return nil
}
func (up *Userpanel) setDownRate(UID []byte, rate int64) error {
err := up.updateDBEntryInt64(UID, "DownRate", rate)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setDownRate(rate)
return nil
}
func (up *Userpanel) setUpCredit(UID []byte, n int64) error {
err := up.updateDBEntryInt64(UID, "UpCredit", n)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setUpCredit(n)
return nil
}
func (up *Userpanel) setDownCredit(UID []byte, n int64) error {
err := up.updateDBEntryInt64(UID, "DownCredit", n)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setDownCredit(n)
return nil
}
func (up *Userpanel) setExpiryTime(UID []byte, time int64) error {
err := up.updateDBEntryInt64(UID, "ExpiryTime", time)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setExpiryTime(time)
return nil
}
func (up *Userpanel) addUpCredit(UID []byte, delta int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
old := b.Get([]byte("UpCredit"))
new := int64(Uint64(old)) + delta
if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.addUpCredit(delta)
return nil
}
func (up *Userpanel) addDownCredit(UID []byte, delta int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
old := b.Get([]byte("DownCredit"))
new := int64(Uint64(old)) + delta
if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.addDownCredit(delta)
return nil
}

View File

@ -1,31 +0,0 @@
package usermanager
type Voidmanager struct{}
func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) {
return 0, 0, ErrMangerIsVoid
}
func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error {
return ErrMangerIsVoid
}
func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) {
return nil, ErrMangerIsVoid
}
func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) {
return []UserInfo{}, ErrMangerIsVoid
}
func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) {
return UserInfo{}, ErrMangerIsVoid
}
func (v *Voidmanager) WriteUserInfo(info UserInfo) error {
return ErrMangerIsVoid
}
func (v *Voidmanager) DeleteUser(UID []byte) error {
return ErrMangerIsVoid
}

View File

@ -1,44 +0,0 @@
package usermanager
import (
"testing"
"github.com/stretchr/testify/assert"
)
var v = &Voidmanager{}
func Test_Voidmanager_AuthenticateUser(t *testing.T) {
_, _, err := v.AuthenticateUser([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_AuthoriseNewSession(t *testing.T) {
err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_DeleteUser(t *testing.T) {
err := v.DeleteUser([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_GetUserInfo(t *testing.T) {
_, err := v.GetUserInfo([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_ListAllUsers(t *testing.T) {
_, err := v.ListAllUsers()
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_UploadStatus(t *testing.T) {
_, err := v.UploadStatus([]StatusUpdate{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_WriteUserInfo(t *testing.T) {
err := v.WriteUserInfo(UserInfo{})
assert.Equal(t, ErrMangerIsVoid, err)
}

View File

@ -1,223 +0,0 @@
package server
import (
"encoding/base64"
"sync"
"sync/atomic"
"time"
"github.com/cbeuw/Cloak/internal/server/usermanager"
mux "github.com/cbeuw/Cloak/internal/multiplex"
log "github.com/sirupsen/logrus"
)
const defaultUploadInterval = 1 * time.Minute
// userPanel is used to authenticate new users and book keep active users
type userPanel struct {
Manager usermanager.UserManager
activeUsersM sync.RWMutex
activeUsers map[[16]byte]*ActiveUser
usageUpdateQueueM sync.Mutex
usageUpdateQueue map[[16]byte]*usagePair
uploadInterval time.Duration
}
func MakeUserPanel(manager usermanager.UserManager) *userPanel {
ret := &userPanel{
Manager: manager,
activeUsers: make(map[[16]byte]*ActiveUser),
usageUpdateQueue: make(map[[16]byte]*usagePair),
uploadInterval: defaultUploadInterval,
}
go ret.regularQueueUpload()
return ret
}
// GetBypassUser does the same as GetUser except it unconditionally creates an ActiveUser when the UID isn't already active
func (panel *userPanel) GetBypassUser(UID []byte) (*ActiveUser, error) {
panel.activeUsersM.Lock()
defer panel.activeUsersM.Unlock()
var arrUID [16]byte
copy(arrUID[:], UID)
if user, ok := panel.activeUsers[arrUID]; ok {
return user, nil
}
user := &ActiveUser{
panel: panel,
valve: mux.UNLIMITED_VALVE,
sessions: make(map[uint32]*mux.Session),
bypass: true,
}
copy(user.arrUID[:], UID)
panel.activeUsers[user.arrUID] = user
return user, nil
}
// GetUser retrieves the reference to an ActiveUser if it's already active, or creates a new ActiveUser of specified
// UID with UserInfo queried from the UserManger, should the particular UID is allowed to connect
func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) {
panel.activeUsersM.Lock()
defer panel.activeUsersM.Unlock()
var arrUID [16]byte
copy(arrUID[:], UID)
if user, ok := panel.activeUsers[arrUID]; ok {
return user, nil
}
upRate, downRate, err := panel.Manager.AuthenticateUser(UID)
if err != nil {
return nil, err
}
valve := mux.MakeValve(upRate, downRate)
user := &ActiveUser{
panel: panel,
valve: valve,
sessions: make(map[uint32]*mux.Session),
}
copy(user.arrUID[:], UID)
panel.activeUsers[user.arrUID] = user
log.WithFields(log.Fields{
"UID": base64.StdEncoding.EncodeToString(UID),
}).Info("New active user")
return user, nil
}
// TerminateActiveUser terminates a user and deletes its references
func (panel *userPanel) TerminateActiveUser(user *ActiveUser, reason string) {
log.WithFields(log.Fields{
"UID": base64.StdEncoding.EncodeToString(user.arrUID[:]),
"reason": reason,
}).Info("Terminating active user")
panel.updateUsageQueueForOne(user)
user.closeAllSessions(reason)
panel.activeUsersM.Lock()
delete(panel.activeUsers, user.arrUID)
panel.activeUsersM.Unlock()
}
func (panel *userPanel) isActive(UID []byte) bool {
var arrUID [16]byte
copy(arrUID[:], UID)
panel.activeUsersM.RLock()
_, ok := panel.activeUsers[arrUID]
panel.activeUsersM.RUnlock()
return ok
}
type usagePair struct {
up *int64
down *int64
}
// updateUsageQueue zeroes the accumulated usage all ActiveUsers valve and put the usage data im usageUpdateQueue
func (panel *userPanel) updateUsageQueue() {
panel.activeUsersM.Lock()
panel.usageUpdateQueueM.Lock()
for _, user := range panel.activeUsers {
if user.bypass {
continue
}
upIncured, downIncured := user.valve.Nullify()
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
atomic.AddInt64(usage.up, upIncured)
atomic.AddInt64(usage.down, downIncured)
} else {
// if the user hasn't been added to the queue
usage = &usagePair{&upIncured, &downIncured}
panel.usageUpdateQueue[user.arrUID] = usage
}
}
panel.activeUsersM.Unlock()
panel.usageUpdateQueueM.Unlock()
}
// updateUsageQueueForOne is the same as updateUsageQueue except it only updates one user's usage
// this is useful when the user is being terminated
func (panel *userPanel) updateUsageQueueForOne(user *ActiveUser) {
// used when one particular user deactivates
if user.bypass {
return
}
upIncured, downIncured := user.valve.Nullify()
panel.usageUpdateQueueM.Lock()
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
atomic.AddInt64(usage.up, upIncured)
atomic.AddInt64(usage.down, downIncured)
} else {
usage = &usagePair{&upIncured, &downIncured}
panel.usageUpdateQueue[user.arrUID] = usage
}
panel.usageUpdateQueueM.Unlock()
}
// commitUpdate put all usageUpdates into a slice of StatusUpdate, calls Manager.UploadStatus, gets the responses
// and act to each user according to the responses
func (panel *userPanel) commitUpdate() error {
panel.usageUpdateQueueM.Lock()
statuses := make([]usermanager.StatusUpdate, 0, len(panel.usageUpdateQueue))
for arrUID, usage := range panel.usageUpdateQueue {
panel.activeUsersM.RLock()
user := panel.activeUsers[arrUID]
panel.activeUsersM.RUnlock()
var numSession int
if user != nil {
if user.bypass {
continue
}
numSession = user.NumSession()
}
status := usermanager.StatusUpdate{
UID: arrUID[:],
Active: panel.isActive(arrUID[:]),
NumSession: numSession,
UpUsage: *usage.up,
DownUsage: *usage.down,
Timestamp: time.Now().Unix(),
}
statuses = append(statuses, status)
}
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
panel.usageUpdateQueueM.Unlock()
if len(statuses) == 0 {
return nil
}
responses, err := panel.Manager.UploadStatus(statuses)
if err != nil {
return err
}
for _, resp := range responses {
var arrUID [16]byte
copy(arrUID[:], resp.UID)
switch resp.Action {
case usermanager.TERMINATE:
panel.activeUsersM.RLock()
user := panel.activeUsers[arrUID]
panel.activeUsersM.RUnlock()
if user != nil {
panel.TerminateActiveUser(user, resp.Message)
}
}
}
return nil
}
func (panel *userPanel) regularQueueUpload() {
for {
time.Sleep(panel.uploadInterval)
go func() {
panel.updateUsageQueue()
err := panel.commitUpdate()
if err != nil {
log.Error(err)
}
}()
}
}

View File

@ -1,190 +0,0 @@
package server
import (
"encoding/base64"
"io/ioutil"
"os"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/server/usermanager"
)
func TestUserPanel_BypassUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
manager, err := usermanager.MakeLocalManager(tmpDB.Name(), common.RealWorldState)
if err != nil {
t.Error("failed to make local manager", err)
}
panel := MakeUserPanel(manager)
UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==")
user, _ := panel.GetBypassUser(UID)
user.valve.AddRx(10)
user.valve.AddTx(10)
t.Run("isActive", func(t *testing.T) {
a := panel.isActive(UID)
if !a {
t.Error("isActive returned ", a)
}
})
t.Run("updateUsageQueue", func(t *testing.T) {
panel.updateUsageQueue()
if _, inQ := panel.usageUpdateQueue[user.arrUID]; inQ {
t.Error("user in update queue")
}
})
t.Run("updateUsageQueueForOne", func(t *testing.T) {
panel.updateUsageQueueForOne(user)
if _, inQ := panel.usageUpdateQueue[user.arrUID]; inQ {
t.Error("user in update queue")
}
})
t.Run("commitUpdate", func(t *testing.T) {
err := panel.commitUpdate()
if err != nil {
t.Error("commit returned", err)
}
})
t.Run("TerminateActiveUser", func(t *testing.T) {
panel.TerminateActiveUser(user, "")
if panel.isActive(user.arrUID[:]) {
t.Error("user still active after deletion", err)
}
})
t.Run("Repeated delete", func(t *testing.T) {
panel.TerminateActiveUser(user, "")
})
err = manager.Close()
if err != nil {
t.Error("failed to close localmanager", err)
}
}
var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var validUserInfo = usermanager.UserInfo{
UID: mockUID,
SessionsCap: usermanager.JustInt32(10),
UpRate: usermanager.JustInt64(100),
DownRate: usermanager.JustInt64(1000),
UpCredit: usermanager.JustInt64(10000),
DownCredit: usermanager.JustInt64(100000),
ExpiryTime: usermanager.JustInt64(1000000),
}
func TestUserPanel_GetUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
mgr, err := usermanager.MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
panel := MakeUserPanel(mgr)
t.Run("normal user", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
activeUser, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
again, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Errorf("can't get existing user: %v", err)
}
if activeUser != again {
t.Error("got different references")
}
})
t.Run("non existent user", func(t *testing.T) {
_, err = panel.GetUser(make([]byte, 16))
if err != usermanager.ErrUserNotFound {
t.Errorf("expecting error %v, got %v", usermanager.ErrUserNotFound, err)
}
})
}
func TestUserPanel_UpdateUsageQueue(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
mgr, err := usermanager.MakeLocalManager(tmpDB.Name(), mockWorldState)
if err != nil {
t.Fatal(err)
}
panel := MakeUserPanel(mgr)
t.Run("normal update", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
user, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
user.valve.AddTx(1)
user.valve.AddRx(2)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
t.Error(err)
}
if user.valve.GetRx() != 0 || user.valve.GetTx() != 0 {
t.Error("rx and tx stats are not cleared")
}
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
t.Error("down credit incorrect update")
}
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
t.Error("up credit incorrect update")
}
// another update
user.valve.AddTx(3)
user.valve.AddRx(4)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
t.Error(err)
}
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
t.Error("down credit incorrect update")
}
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
t.Error("up credit incorrect update")
}
})
t.Run("terminating update", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
user, err := panel.GetUser(validUserInfo.UID)
if err != nil {
t.Error(err)
}
user.valve.AddTx(*validUserInfo.DownCredit + 100)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
t.Error(err)
}
if panel.isActive(validUserInfo.UID) {
t.Error("user not terminated")
}
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if *updatedUinfo.DownCredit != -100 {
t.Error("down credit not updated correctly after the user has been terminated")
}
})
}

View File

@ -1,103 +0,0 @@
package server
import (
"bufio"
"bytes"
"crypto"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh"
)
type WebSocket struct{}
func (WebSocket) String() string { return "WebSocket" }
func (WebSocket) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) {
var req *http.Request
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket)))
if err != nil {
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
return
}
var hiddenData []byte
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
fragments, err = WebSocket{}.unmarshalHidden(hiddenData, privateKey)
if err != nil {
err = fmt.Errorf("failed to unmarshal hidden data from WS into authFragments: %v", err)
return
}
respond = WebSocket{}.makeResponder(reqPacket, fragments.sharedSecret)
return
}
func (WebSocket) makeResponder(reqPacket []byte, sharedSecret [32]byte) Responder {
respond := func(originalConn net.Conn, sessionKey [32]byte, randSource io.Reader) (preparedConn net.Conn, err error) {
handler := newWsHandshakeHandler()
// For an explanation of the following 3 lines, see the comments in websocketAux.go
http.Serve(newWsAcceptor(originalConn, reqPacket), handler)
<-handler.finished
preparedConn = handler.conn
nonce := make([]byte, 12)
common.RandRead(randSource, nonce)
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
encryptedKey, err := common.AESGCMEncrypt(nonce, sharedSecret[:], sessionKey[:]) // 32 + 16 = 48 bytes
if err != nil {
err = fmt.Errorf("failed to encrypt reply: %v", err)
return
}
reply := append(nonce, encryptedKey...)
_, err = preparedConn.Write(reply)
if err != nil {
err = fmt.Errorf("failed to write reply: %v", err)
preparedConn.Close()
return
}
return
}
return respond
}
var ErrBadGET = errors.New("non (or malformed) HTTP GET")
func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fragments authFragments, err error) {
if len(hidden) < 96 {
err = ErrBadGET
return
}
copy(fragments.randPubKey[:], hidden[0:32])
ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:])
if !ok {
err = ErrInvalidPubKey
return
}
var sharedSecret []byte
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return
}
copy(fragments.sharedSecret[:], sharedSecret)
if len(hidden[32:]) != 64 {
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:]))
return
}
copy(fragments.ciphertextWithTag[:], hidden[32:])
return
}

View File

@ -1,138 +0,0 @@
package server
import (
"errors"
"net"
"net/http"
"github.com/cbeuw/Cloak/internal/common"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
//
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
//
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
// net/http package upon receiving a request from a Conn.
//
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
// function.
//
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
// accepted.
//
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
// Accept will return error (so that the caller won't call again)
//
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
//
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
// websocket.Conn
//
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
//
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
// execution will block until the reference to util.WebSocketConn is ready.
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
// fake a conn that returns the first packet on first read
type firstBuffedConn struct {
net.Conn
firstRead bool
firstPacket []byte
}
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
copy(buf, c.firstPacket)
n := len(c.firstPacket)
c.firstPacket = []byte{}
return n, nil
}
return c.Conn.Read(buf)
}
type wsOnceListener struct {
done bool
c *firstBuffedConn
}
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
// http.Server. This is an acceptor that accepts only one Conn
func newWsAcceptor(conn net.Conn, first []byte) *wsOnceListener {
f := make([]byte, len(first))
copy(f, first)
return &wsOnceListener{
c: &firstBuffedConn{Conn: conn, firstPacket: f},
}
}
func (w *wsOnceListener) Accept() (net.Conn, error) {
if w.done {
return nil, errors.New("already accepted")
}
w.done = true
return w.c, nil
}
func (w *wsOnceListener) Close() error {
w.done = true
return nil
}
func (w *wsOnceListener) Addr() net.Addr {
return w.c.LocalAddr()
}
type wsHandshakeHandler struct {
conn net.Conn
finished chan struct{}
}
// the handler to turn a net.Conn into a websocket.Conn
func newWsHandshakeHandler() *wsHandshakeHandler {
return &wsHandshakeHandler{finished: make(chan struct{})}
}
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("failed to upgrade connection to ws: %v", err)
return
}
ws.conn = &common.WebSocketConn{Conn: c}
ws.finished <- struct{}{}
}

View File

@ -1,59 +0,0 @@
package server
import (
"bytes"
"testing"
"github.com/cbeuw/connutil"
)
func TestFirstBuffedConn_Read(t *testing.T) {
mockConn, writingEnd := connutil.AsyncPipe()
expectedFirstPacket := []byte{1, 2, 3}
firstBuffedConn := &firstBuffedConn{
Conn: mockConn,
firstRead: false,
firstPacket: expectedFirstPacket,
}
buf := make([]byte, 1024)
n, err := firstBuffedConn.Read(buf)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(expectedFirstPacket, buf[:n]) {
t.Error("first read doesn't produce given packet")
return
}
expectedSecondPacket := []byte{4, 5, 6, 7}
writingEnd.Write(expectedSecondPacket)
n, err = firstBuffedConn.Read(buf)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(expectedSecondPacket, buf[:n]) {
t.Error("second read doesn't produce subsequently written packet")
return
}
}
func TestWsAcceptor(t *testing.T) {
mockConn := connutil.Discard()
expectedFirstPacket := []byte{1, 2, 3}
wsAcceptor := newWsAcceptor(mockConn, expectedFirstPacket)
_, err := wsAcceptor.Accept()
if err != nil {
t.Error(err)
return
}
_, err = wsAcceptor.Accept()
if err == nil {
t.Error("accepting second time doesn't return error")
}
}

View File

@ -1,575 +0,0 @@
package test
import (
"bytes"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"math/rand"
"net"
"sync"
"testing"
"time"
"github.com/cbeuw/Cloak/internal/client"
"github.com/cbeuw/Cloak/internal/common"
mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server"
"github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert"
log "github.com/sirupsen/logrus"
)
const numConns = 200 // -race option limits the number of goroutines to 8192
func serveTCPEcho(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
log.Error(err)
return
}
go func(conn net.Conn) {
_, err := io.Copy(conn, conn)
if err != nil {
conn.Close()
log.Error(err)
return
}
}(conn)
}
}
func serveUDPEcho(listener *connutil.PipeListener) {
for {
conn, err := listener.ListenPacket("udp", "")
if err != nil {
log.Error(err)
return
}
const bufSize = 32 * 1024
go func(conn net.PacketConn) {
defer conn.Close()
buf := make([]byte, bufSize)
for {
r, _, err := conn.ReadFrom(buf)
if err != nil {
log.Error(err)
return
}
w, err := conn.WriteTo(buf[:r], nil)
if err != nil {
log.Error(err)
return
}
if r != w {
log.Error("written not eqal to read")
return
}
}
}(conn)
}
}
var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=")
var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=")
var basicUDPConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "openvpn",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 4,
UDP: true,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
}
var basicTCPConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "shadowsocks",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 4,
UDP: false,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
BrowserSig: "firefox",
}
var singleplexTCPConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "shadowsocks",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
NumConn: 0,
UDP: false,
Transport: "direct",
RemoteHost: "fake.com",
RemotePort: "9999",
LocalHost: "127.0.0.1",
LocalPort: "9999",
BrowserSig: "safari",
}
func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) {
lcl, rmt, auth, err := rawConfig.ProcessRawConfig(state)
if err != nil {
log.Fatal(err)
}
return lcl, rmt, auth
}
func basicServerState(ws common.WorldState) *server.State {
var serverConfig = server.RawConfig{
ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
BindAddr: []string{"fake.com:9999"},
BypassUID: [][]byte{bypassUID[:]},
RedirAddr: "fake.com:9999",
PrivateKey: privateKey,
KeepAlive: 15,
CncMode: false,
}
state, err := server.InitState(serverConfig, ws)
if err != nil {
log.Fatal(err)
}
return state
}
type mockUDPDialer struct {
addrCh chan *net.UDPAddr
raddr *net.UDPAddr
}
func (m *mockUDPDialer) Dial(network, address string) (net.Conn, error) {
if m.raddr == nil {
m.raddr = <-m.addrCh
}
return net.DialUDP("udp", nil, m.raddr)
}
func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) {
// redirecting web server
// ^
// |
// |
// redirFromCkServerL
// |
// |
// proxy client ----proxyToCkClientD----> ck-client ------> ck-server ----proxyFromCkServerL----> proxy server
// ^
// |
// |
// netToCkServerD
// |
// |
// whatever connection initiator (including a proper ck-client)
netToCkServerD, ckServerListener := connutil.DialerListener(10 * 1024)
clientSeshMaker := func() *mux.Session {
ai := ai
quad := make([]byte, 4)
common.RandRead(ai.WorldState.Rand, quad)
ai.SessionId = binary.BigEndian.Uint32(quad)
return client.MakeSession(rcc, ai, netToCkServerD)
}
var proxyToCkClientD common.Dialer
if ai.Unordered {
// We can only "dial" a single UDP connection as we can't send packets from different context
// to a single UDP listener
addrCh := make(chan *net.UDPAddr, 1)
mDialer := &mockUDPDialer{
addrCh: addrCh,
}
acceptor := func() (*net.UDPConn, error) {
laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, err := net.ListenUDP("udp", laddr)
addrCh <- conn.LocalAddr().(*net.UDPAddr)
return conn, err
}
go client.RouteUDP(acceptor, lcc.Timeout, rcc.Singleplex, clientSeshMaker)
proxyToCkClientD = mDialer
} else {
var proxyToCkClientL *connutil.PipeListener
proxyToCkClientD, proxyToCkClientL = connutil.DialerListener(10 * 1024)
go client.RouteTCP(proxyToCkClientL, lcc.Timeout, rcc.Singleplex, clientSeshMaker)
}
// set up server
ckServerToProxyD, proxyFromCkServerL := connutil.DialerListener(10 * 1024)
ckServerToWebD, redirFromCkServerL := connutil.DialerListener(10 * 1024)
serverState.ProxyDialer = ckServerToProxyD
serverState.RedirDialer = ckServerToWebD
go server.Serve(ckServerListener, serverState)
return proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, nil
}
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
testData := make([]byte, msgLen)
rand.Read(testData)
// we cannot call t.Fatalf in concurrent contexts
n, err := conn.Write(testData)
if n != msgLen {
t.Errorf("written only %v, err %v", n, err)
return
}
recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf)
if err != nil {
t.Errorf("failed to read back: %v", err)
return
}
if !bytes.Equal(testData, recvBuf) {
t.Errorf("echoed data not correct")
return
}
}(conn)
}
wg.Wait()
}
func TestUDP(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("simple send", func(t *testing.T) {
pxyClientConn, err := proxyToCkClientD.Dial("udp", "")
if err != nil {
t.Error(err)
}
const testDataLen = 1500
testData := make([]byte, testDataLen)
rand.Read(testData)
n, err := pxyClientConn.Write(testData)
if n != testDataLen {
t.Errorf("wrong length sent: %v", n)
}
if err != nil {
t.Error(err)
}
pxyServerConn, err := proxyFromCkServerL.ListenPacket("", "")
if err != nil {
t.Error(err)
}
recvBuf := make([]byte, testDataLen+100)
r, _, err := pxyServerConn.ReadFrom(recvBuf)
if err != nil {
t.Error(err)
}
if !bytes.Equal(testData, recvBuf[:r]) {
t.Error("read wrong data")
}
})
const echoMsgLen = 1024
t.Run("user echo", func(t *testing.T) {
go serveUDPEcho(proxyFromCkServerL)
var conn [1]net.Conn
conn[0], err = proxyToCkClientD.Dial("udp", "")
if err != nil {
t.Error(err)
}
runEchoTest(t, conn[:], echoMsgLen)
})
}
func TestTCPSingleplex(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
const echoMsgLen = 1 << 16
go serveTCPEcho(proxyFromCkServerL)
proxyConn1, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Fatal(err)
}
runEchoTest(t, []net.Conn{proxyConn1}, echoMsgLen)
user, err := sta.Panel.GetUser(ai.UID[:])
if err != nil {
t.Fatalf("failed to fetch user: %v", err)
}
if user.NumSession() != 1 {
t.Error("no session were made on first connection establishment")
}
proxyConn2, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Fatal(err)
}
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
if user.NumSession() != 2 {
t.Error("no extra session were made on second connection establishment")
}
// Both conns should work
runEchoTest(t, []net.Conn{proxyConn1, proxyConn2}, echoMsgLen)
proxyConn1.Close()
assert.Eventually(t, func() bool {
return user.NumSession() == 1
}, time.Second, 10*time.Millisecond, "first session was not closed on connection close")
// conn2 should still work
runEchoTest(t, []net.Conn{proxyConn2}, echoMsgLen)
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
conns[i], err = proxyToCkClientD.Dial("", "")
if err != nil {
t.Fatal(err)
}
}
runEchoTest(t, conns[:], echoMsgLen)
}
func TestTCPMultiplex(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("user echo single", func(t *testing.T) {
for i := 0; i < 18; i += 2 {
dataLen := 1 << i
writeData := make([]byte, dataLen)
rand.Read(writeData)
t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) {
go serveTCPEcho(proxyFromCkServerL)
conn, err := proxyToCkClientD.Dial("", "")
if err != nil {
t.Error(err)
}
n, err := conn.Write(writeData)
if err != nil {
t.Error(err)
}
if n != dataLen {
t.Errorf("write length doesn't match up: %v, expected %v", n, dataLen)
}
recvBuf := make([]byte, dataLen)
_, err = io.ReadFull(conn, recvBuf)
if err != nil {
t.Error(err)
}
if !bytes.Equal(writeData, recvBuf) {
t.Error("echoed data incorrect")
}
})
}
})
const echoMsgLen = 16384
t.Run("user echo", func(t *testing.T) {
go serveTCPEcho(proxyFromCkServerL)
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
conns[i], err = proxyToCkClientD.Dial("", "")
if err != nil {
t.Error(err)
}
}
runEchoTest(t, conns[:], echoMsgLen)
})
t.Run("redir echo", func(t *testing.T) {
go serveTCPEcho(redirFromCkServerL)
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
conns[i], err = netToCkServerD.Dial("", "")
if err != nil {
t.Error(err)
}
}
runEchoTest(t, conns[:], echoMsgLen)
})
}
func TestClosingStreamsFromProxy(t *testing.T) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
for clientConfigName, clientConfig := range map[string]client.RawConfig{"basic": basicTCPConfig, "singleplex": singleplexTCPConfig} {
clientConfig := clientConfig
clientConfigName := clientConfigName
t.Run(clientConfigName, func(t *testing.T) {
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("closing from server", func(t *testing.T) {
clientConn, _ := proxyToCkClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := proxyFromCkServerL.Accept()
serverConn.Close()
assert.Eventually(t, func() bool {
_, err := clientConn.Read(make([]byte, 16))
return err != nil
}, time.Second, 10*time.Millisecond, "closing stream on server side is not reflected to the client")
})
t.Run("closing from client", func(t *testing.T) {
// closing stream on client side
clientConn, _ := proxyToCkClientD.Dial("", "")
clientConn.Write(make([]byte, 16))
serverConn, _ := proxyFromCkServerL.Accept()
clientConn.Close()
assert.Eventually(t, func() bool {
_, err := serverConn.Read(make([]byte, 16))
return err != nil
}, time.Second, 10*time.Millisecond, "closing stream on client side is not reflected to the server")
})
t.Run("send then close", func(t *testing.T) {
testData := make([]byte, 24*1024)
rand.Read(testData)
clientConn, _ := proxyToCkClientD.Dial("", "")
go func() {
clientConn.Write(testData)
// it takes time for this written data to be copied asynchronously
// into ck-server's domain. If the pipe is closed before that, read
// by ck-client in RouteTCP will fail as we have closed it.
time.Sleep(700 * time.Millisecond)
clientConn.Close()
}()
readBuf := make([]byte, len(testData))
serverConn, err := proxyFromCkServerL.Accept()
if err != nil {
t.Errorf("failed to accept a connection delievering data sent before closing: %v", err)
}
_, err = io.ReadFull(serverConn, readBuf)
if err != nil {
t.Errorf("failed to read data sent before closing: %v", err)
}
})
})
}
}
func BenchmarkIntegration(b *testing.B) {
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
sta := basicServerState(worldState)
const bufSize = 16 * 1024
encryptionMethods := map[string]byte{
"plain": mux.EncryptionMethodPlain,
"chacha20-poly1305": mux.EncryptionMethodChaha20Poly1305,
"aes-256-gcm": mux.EncryptionMethodAES256GCM,
"aes-128-gcm": mux.EncryptionMethodAES128GCM,
}
for name, method := range encryptionMethods {
b.Run(name, func(b *testing.B) {
ai.EncryptionMethod = method
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
b.Fatal(err)
}
b.Run("single stream bandwidth", func(b *testing.B) {
more := make(chan int, 10)
go func() {
// sender
writeBuf := make([]byte, bufSize+100)
serverConn, _ := proxyFromCkServerL.Accept()
for {
serverConn.Write(writeBuf)
<-more
}
}()
// receiver
clientConn, _ := proxyToCkClientD.Dial("", "")
readBuf := make([]byte, bufSize)
clientConn.Write([]byte{1}) // to make server accept
b.SetBytes(bufSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
io.ReadFull(clientConn, readBuf)
// ask for more
more <- 0
}
})
b.Run("single stream latency", func(b *testing.B) {
clientConn, _ := proxyToCkClientD.Dial("", "")
buf := []byte{1}
clientConn.Write(buf)
serverConn, _ := proxyFromCkServerL.Accept()
serverConn.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
clientConn.Write(buf)
serverConn.Read(buf)
}
})
})
}
}

View File

@ -1 +0,0 @@
package test

104
internal/util/util.go Normal file
View File

@ -0,0 +1,104 @@
package util
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
prand "math/rand"
"net"
"strconv"
)
func AESEncrypt(iv []byte, key []byte, plaintext []byte) []byte {
block, _ := aes.NewCipher(key)
ciphertext := make([]byte, len(plaintext))
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(ciphertext, plaintext)
return ciphertext
}
func AESDecrypt(iv []byte, key []byte, ciphertext []byte) []byte {
ret := make([]byte, len(ciphertext))
copy(ret, ciphertext) // Because XORKeyStream is inplace, but we don't want the input to be changed
block, _ := aes.NewCipher(key)
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(ret, ret)
return ret
}
// BtoInt converts a byte slice into int in Big Endian order
// Uint methods from binary package can be used, but they are messy
func BtoInt(b []byte) int {
var mult uint = 1
var sum uint
length := uint(len(b))
var i uint
for i = 0; i < length; i++ {
sum += uint(b[i]) * (mult << ((length - i - 1) * 8))
}
return int(sum)
}
// PsudoRandBytes returns a byte slice filled with psudorandom bytes generated by the seed
func PsudoRandBytes(length int, seed int64) []byte {
r := prand.New(prand.NewSource(seed))
ret := make([]byte, length)
r.Read(ret)
return ret
}
// ReadTLS reads TLS data according to its record layer
func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) {
// TCP is a stream. Multiple TLS messages can arrive at the same time,
// a single message can also be segmented due to MTU of the IP layer.
// This function guareentees a single TLS message to be read and everything
// else is left in the buffer.
i, err := io.ReadFull(conn, buffer[:5])
if err != nil {
return
}
dataLength := BtoInt(buffer[3:5])
if dataLength > len(buffer) {
err = errors.New("Reading TLS message: message size greater than buffer. message size: " + strconv.Itoa(dataLength))
return
}
left := dataLength
readPtr := 5
for left != 0 {
// If left > buffer size (i.e. our message got segmented), the entire MTU is read
// if left = buffer size, the entire buffer is all there left to read
// if left < buffer size (i.e. multiple messages came together),
// only the message we want is read
i, err = io.ReadFull(conn, buffer[readPtr:readPtr+left])
if err != nil {
return
}
left -= i
readPtr += i
}
n = 5 + dataLength
return
}
// AddRecordLayer adds record layer to data
func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
length := make([]byte, 2)
binary.BigEndian.PutUint16(length, uint16(len(input)))
ret := make([]byte, 5+len(input))
copy(ret[0:1], typ)
copy(ret[1:3], ver)
copy(ret[3:5], length)
copy(ret[5:], input)
return ret
}
// PeelRecordLayer peels off the record layer
func PeelRecordLayer(data []byte) []byte {
ret := data[5:]
return ret
}

View File

@ -1,37 +0,0 @@
#!/usr/bin/env bash
set -eu
go install github.com/mitchellh/gox@latest
mkdir -p release
rm -f ./release/*
if [ -z "$v" ]; then
echo "Version number cannot be null. Run with v=[version] release.sh"
exit 1
fi
output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v"
osarch="!darwin/arm !darwin/386"
echo "Compiling:"
os="windows linux darwin"
arch="amd64 386 arm arm64 mips mips64 mipsle mips64le"
pushd cmd/ck-client
CGO_ENABLED=0 gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output"
CGO_ENABLED=0 GOOS="linux" GOARCH="mips" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mips_softfloat-"${v}"
CGO_ENABLED=0 GOOS="linux" GOARCH="mipsle" GOMIPS="softfloat" go build -ldflags "-X main.version=${v}" -o ck-client-linux-mipsle_softfloat-"${v}"
mv ck-client-* ../../release
popd
os="linux"
arch="amd64 386 arm arm64"
pushd cmd/ck-server
CGO_ENABLED=0 gox -ldflags "-X main.version=${v}" -os="$os" -arch="$arch" -osarch="$osarch" -output="$output"
mv ck-server-* ../../release
popd
sha256sum release/*

View File

@ -1,13 +0,0 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended"
],
"packageRules": [
{
"packagePatterns": ["*"],
"excludePackagePatterns": ["utls"],
"enabled": false
}
]
}