blob: 8532f7fee668bca90894ce1f6ffeb6ed4183a380 [file] [log] [blame]
// Copyright 2016 The Android Open Source Project
//
// This software is licensed under the terms of the GNU General Public
// License version 2, as published by the Free Software Foundation, and
// may be copied, distributed, and modified under those terms.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
#include "android/base/async/AsyncSocketServer.h"
#include "android/base/async/Looper.h"
#include "android/base/async/ScopedSocketWatch.h"
#include "android/base/async/ThreadLooper.h"
#include "android/base/Log.h"
#include "android/base/sockets/ScopedSocket.h"
#include "android/base/sockets/SocketUtils.h"
#include <memory>
namespace android {
namespace base {
using FdWatch = android::base::Looper::FdWatch;
namespace {
class BaseSocketServer : public AsyncSocketServer {
public:
BaseSocketServer(Looper* looper,
int port,
ConnectCallback connectCallback,
int socket4,
int socket6)
: mLooper(looper), mPort(port), mConnectCallback(connectCallback) {
if (socket4 >= 0) {
mBoundSocket4.reset(looper->createFdWatch(socket4, onAccept, this));
CHECK(mBoundSocket4.get());
}
if (socket6 >= 0) {
mBoundSocket6.reset(looper->createFdWatch(socket6, onAccept, this));
CHECK(mBoundSocket6.get());
}
}
virtual ~BaseSocketServer() = default;
virtual int port() const override { return mPort; }
virtual void startListening() override {
if (mBoundSocket4) {
mBoundSocket4->wantRead();
}
if (mBoundSocket6) {
mBoundSocket6->wantRead();
}
};
virtual void stopListening() override {
if (mBoundSocket4) {
mBoundSocket4->dontWantRead();
}
if (mBoundSocket6) {
mBoundSocket6->dontWantRead();
}
};
virtual LoopbackMode getListenMode() const override {
return mBoundSocket4 ? (mBoundSocket6 ? kIPv4AndIPv6 : kIPv4)
: (mBoundSocket6 ? kIPv6 : kNone);
}
private:
// Called when an i/o event happens on the bound socket. Typically
// a read event indicates a new client connection.
static void onAccept(void* opaque, int fd, unsigned events) {
auto server = reinterpret_cast<BaseSocketServer*>(opaque);
if (events & FdWatch::kEventRead) {
int clientFd = socketAcceptAny(fd);
if (clientFd < 0) {
// This can happen when the client closed the connection
// before we could process it. Just exit and try until the
// next time.
PLOG(WARNING) << "Error when accepting host connection";
return;
}
// Stop listening now, this lets the callback call startListening()
// if it wants to allow new connections immediately. It also means
// we need to call it explicitly in case of error.
server->stopListening();
if (!server->mConnectCallback(clientFd)) {
// Assume the callback printed any necessary error message.
socketClose(clientFd);
server->startListening();
return;
}
}
}
ScopedSocketWatch mBoundSocket4;
ScopedSocketWatch mBoundSocket6;
Looper* mLooper = nullptr;
int mPort = 0;
ConnectCallback mConnectCallback;
};
} // namespace
// static
std::unique_ptr<AsyncSocketServer> AsyncSocketServer::createTcpLoopbackServer(
int port,
ConnectCallback connectCallback,
LoopbackMode mode,
Looper* looper) {
if (!looper) {
looper = ThreadLooper::get();
}
ScopedSocket sock4;
ScopedSocket sock6;
// If |port| is 0, we let the system find a random free port, but
// in the case of kIPv4AndIPv6, we need to ensure that it is possible
// to bind to both ports. Loop a few times to ensure that this is the
// case.
int tries = (port == 0 && mode == kIPv4AndIPv6) ? 5 : 1;
int port0 = port;
bool success = false;
for (; tries > 0; tries--) {
port = port0;
success = true;
if ((mode & kIPv4) != 0) {
sock4.reset(socketTcp4LoopbackServer(port));
if (!sock4.valid()) {
if (!(mode & kIPv4Optional)) {
success = false;
}
} else if (!port) {
port = socketGetPort(sock4.get());
}
}
if ((mode & kIPv6) != 0) {
sock6.reset(socketTcp6LoopbackServer(port));
if (!sock6.valid()) {
if (!(mode & kIPv6Optional)) {
success = false;
}
} else if (!port) {
port = socketGetPort(sock6.get());
}
}
if (success) { break; }
}
AsyncSocketServer* result = nullptr;
if (success && (sock4.valid() || sock6.valid())) {
result = new BaseSocketServer(looper, port, connectCallback,
sock4.release(), sock6.release());
}
return std::unique_ptr<AsyncSocketServer>(result);
}
} // namespace base
} // namespace android