SpringBoot + Netty 实现 TCP拆包粘包处理、TCP恶意连接拦截
================ 代码实现过程 ================
NettyServer:创建TCP服务
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
@Slf4j
public class NettyServer {
private void startServer() {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workGroup = new NioEventLoopGroup();
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workGroup);
b.channel(NioServerSocketChannel.class);
b.option(ChannelOption.SO_BACKLOG, 1024);
b.childOption(ChannelOption.TCP_NODELAY, true);
b.childOption(ChannelOption.SO_KEEPALIVE, true);
b.childHandler(new DoorInitChannel());
ChannelFuture sync = b.bind(9701).sync();
sync.channel().closeFuture().sync();
} catch (Exception e) {
log.error("TCP server init faild: "+e.getMessage();
e.printStackTrace();
} finally {
cleanUp(bossGroup, workGroup);
}
}
private void cleanUp(EventLoopGroup bossGroup, EventLoopGroup workGroup) {
bossGroup.shutdownGracefully();
workGroup.shutdownGracefully();
}
public void init() {
new Thread(() -> {
startServer();
}).start();
}
}
ChannelInit:通道连接事件
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
/**
 * Installs the handler chain on every newly accepted socket channel.
 */
public class ChannelInit extends ChannelInitializer<SocketChannel> {

    /**
     * Appends, in order: the packet decoder, an idle-state detector with a
     * reader-idle time of 60 (writer and all-idle disabled with 0), and the
     * message handler.
     *
     * @param channel the channel being initialised
     */
    @Override
    protected void initChannel(SocketChannel channel) {
        channel.pipeline()
                .addLast("decoder", new BytePacketDecoder())
                .addLast(new IdleStateHandler(60, 0, 0))
                .addLast("handler", new MessageHandler());
    }
}
NettyEvent :Netty连接事件 实体类
import com.fasterxml.jackson.annotation.JsonFormat;
import lombok.Data;
import java.util.Date;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Connection-event record: a timestamp plus a connection counter.
 * Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class NettyEvent {

    /** Connection timestamp; serialised by Jackson as {@code yyyy-MM-dd HH:mm:ss} in GMT+8. */
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
    private Date connectDate;

    /** Connection counter. */
    private AtomicInteger connectCount;

    /**
     * Creates an event with the given timestamp and counter.
     *
     * @param connectDate  connection timestamp
     * @param connectCount connection counter
     */
    public NettyEvent(Date connectDate, AtomicInteger connectCount) {
        this.connectDate = connectDate;
        this.connectCount = connectCount;
    }

    /** Creates an empty event; both fields are left {@code null}. */
    public NettyEvent() {
    }
}
ClientEventManage :客户端事件管理
import com.google.common.cache.*;
import io.netty.channel.Channel;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class ClientEventManage {
public static String CONNECT_CODE = "TCP_CONNECT_";
private static LoadingCache<String, NettyEvent> connectEvent;
public static void connectCheck(Channel ctx) {
NettyEvent cahche = null;
try {
String clientIP = locationInfo(ctx);
ExceptionUtil.isBlank(clientIP, "TCP client ip get is null !");
String key = (CONNECT_CODE + clientIP);
cahche = queryConnectCache(key);
if (!Objects.isNull(cahche)) {
int connectCount = cahche.getConnectCount().incrementAndGet();
log.info("TCP client iP:[" + getClientIp(ctx) + "] connect count:[" + connectCount + "], first connect time: " + DateUtil.getFormatTime(cahche.getConnectDate()));
if (30 <= connectCount) {
cahche.getConnectCount().decrementAndGet();
log.error("TCP client iP:[" + getClientIp(ctx) + "] connect number exceed the 30 limit!");
ctx.close();
}
}
else {
cahche = new NettyEvent(new Date(), new AtomicInteger(0));
connectEvent().put(key, cahche);
}
} catch (Exception e) {
e.printStackTrace();
throw new BusinessException("TCP connect check faild: " + e.getMessage());
}
}
public static String getClientIp(Channel ctx) {
try {
InetSocketAddress ipSocket = (InetSocketAddress) ctx.remoteAddress();
InetAddress address = ipSocket.getAddress();
StringBuffer value = new StringBuffer(address.getHostAddress());
value.append(":");
value.append(ipSocket.getPort());
return value.toString();
} catch (Exception e) {
e.printStackTrace();
log.error("Get tcp request location Faild: " + e.getMessage());
}
return null;
}
public static String locationInfo(Channel ctx) {
try {
InetSocketAddress ipSocket = (InetSocketAddress) ctx.remoteAddress();
InetAddress address = ipSocket.getAddress();
StringBuffer value = new StringBuffer(address.getHostAddress());
return value.toString();
} catch (Exception e) {
e.printStackTrace();
log.error("Get tcp request location Faild: " + e.getMessage());
}
return null;
}
public static NettyEvent queryConnectCache(String key) {
NettyEvent cahceValue = null;
try {
cahceValue = connectEvent().get(key);
} catch (Exception e) {
}
return cahceValue;
}
public static LoadingCache<String, NettyEvent> connectEvent() {
try {
if (Objects.isNull(connectEvent)) {
synchronized (LoadingCache.class) {
if (Objects.isNull(connectEvent)) {
try {
connectEvent = buildCache(new CacheLoader<String, NettyEvent>() {
@Override
public NettyEvent load(String key) {
return null;
}
}, 60, 60);
} catch (Exception e) {
log.error("TCP connect cache, build faild: " + e.getMessage());
}
}
}
}
} catch (Exception e) {
log.error("TCP connect cache exception" + e.getMessage());
}
return connectEvent;
}
private static LoadingCache<String, NettyEvent> buildCache(CacheLoader<String, NettyEvent> cacheLoader, long expireAfterAccess, long expireAfterWrite) {
try {
LoadingCache<String, NettyEvent> cache = CacheBuilder.newBuilder()
.maximumSize(100000)
.expireAfterAccess(expireAfterAccess, TimeUnit.SECONDS)
.expireAfterWrite(expireAfterWrite, TimeUnit.SECONDS)
.removalListener(new RemovalListener<String, NettyEvent>() {
@Override
public void onRemoval(RemovalNotification<String, NettyEvent> rn) {
}
})
.recordStats()
.build(cacheLoader);
return cache;
} catch (Exception e) {
log.error("Loading cache build exception: " + e.getMessage());
return null;
}
}
public static String getFormatTime(Date date) {
SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
return simpleDateFormat.format(date);
}
}
MessageHandler:消息处理器
@Slf4j
public class MessageHandler extends SimpleChannelInboundHandler<Object> {
@Override
public void channelRegistered(ChannelHandlerContext ctx) {
log.info("TCP client:" + ctx.channel() + " =========》》》》》request!");
ClientEventManage.connectCheck(ctx.channel());
}
@Override
public void channelActive(ChannelHandlerContext ctx){
log.info("TCP client:" + ctx.channel() + " connect success!");
}
@Override
public void channelInactive(ChannelHandlerContext ctx){
log.error("TCP client:" + ctx.channel() + " connect close!");
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
try {
log.info("TCP messge :"+msg);
} catch (Exception e) {
log.error("TCP message read exception: " + e.getMessage());
} finally {
ReferenceCountUtil.release(msg);
}
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof IdleStateEvent) {
IdleStateEvent e = (IdleStateEvent) evt;
switch (e.state()) {
case READER_IDLE:
log.error("TCP client:" + ctx.channel() + " connect lose efficacy!");
ctx.channel().close();
break;
case WRITER_IDLE:
break;
case ALL_IDLE:
break;
default:
break;
}
}
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.error("TCP client:" + ctx.channel() + " business exceptions: "+ cause.getMessage());
}
}
MessageDecoder :消息解码器
TCP数据 拆包 粘包 详情图
import com.za.edu.bean.DataPacket;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.util.List;
/**
 * Frame decoder template for splitting and merging TCP packets.
 *
 * <p>NOTE: the Chinese expressions in the body are placeholders for
 * protocol-specific values (header length, header marker, footer length,
 * length parsing, payload parsing) and must be replaced with real code
 * before this class compiles.
 */
public class MessageDecoder extends ByteToMessageDecoder {

    /**
     * Tries to cut one packet out of the accumulated bytes.
     *
     * <p>Phase 1 locates the header and reads the declared packet length; on
     * any failure the reader index is reset and the method returns, leaving
     * the bytes in place. Phase 2 reads body and footer, parses the message
     * and adds it to {@code out}.
     *
     * @param ctx     handler context (unused)
     * @param byteBuf accumulated inbound bytes
     * @param out     receives the decoded message, if any
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) throws Exception {
        byte[] head = null;
        Integer theLength = null;
        Integer bodyLength = null;
        Integer practicaLength = null;
        try {
            // Remember where we started so an incomplete packet can be retried later.
            byteBuf.markReaderIndex();
            practicaLength = byteBuf.readableBytes();
            // NOTE(review): these guards only abort decoding if ExceptionUtil.isTrue
            // actually throws when its condition is true; verify that it does.
            ExceptionUtil.isTrue(practicaLength < ((数据包(头)有效长度)), "Total data length error");
            boolean flag = false;
            // Scan byte by byte for the header marker. No sleeping here: decode()
            // presumably runs on a Netty I/O thread (confirm), and pausing per
            // byte would stall everything else served by that thread.
            while (byteBuf.isReadable()) {
                if (byteBuf.readByte() == (数据包(头)标识 02)) {
                    flag = true;
                    ExceptionUtil.isTrue(practicaLength < ((数据包(头)有效长度)+(数据包(尾)有效长度)) - 1,
                            "Actual data length error!");
                    // Step back one byte so the header read below starts at the marker.
                    int index = (byteBuf.readerIndex() - 1);
                    byteBuf.readerIndex(index);
                    break;
                }
            }
            ExceptionUtil.isTrue(!flag, "Not found data head ! ");
            head = new byte[(数据包(头)默认有效长度)];
            byteBuf.readBytes(head);
            theLength = 根据TCP客户端 推送的消息,并解析获取到 当前数据包的实际长度;
            bodyLength= 根据TCP客户端 推送的消息,减去头尾长度 得到有效数据内容的长度;
            ExceptionUtil.isTrue(服务端 实际接收总长度 < theLength, "Total length is less than actual length ! ");
        } catch (Exception e) {
            e.printStackTrace();
            // Packet incomplete or invalid: rewind and leave the bytes in the buffer.
            byteBuf.resetReaderIndex();
            return;
        }
        byte[] body = null;
        byte[] foot = null;
        // True when more bytes were readable than this packet declares,
        // i.e. further data follows in the same buffer.
        boolean ifAdhesion = (practicaLength > theLength);
        try {
            body = new byte[bodyLength];
            byteBuf.readBytes(body);
            foot = new byte[数据包(尾)默认长度];
            byteBuf.readBytes(foot);
            Object message= 根据客户端消息业务规则,解析出对应的数据内容;
            ExceptionUtil.isNull(message, "TCP message get is null!");
            // Hand the parsed message to the next handler.
            out.add(message);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (!ifAdhesion) {
                // Nothing follows this packet: reset the buffer indices.
                byteBuf.clear();
            }
        }
    }
}
ExceptionUtil :异常工具类
import lombok.extern.slf4j.Slf4j;
import org.junit.platform.commons.util.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.regex.Pattern;
/**
 * Guard helpers that log a message at debug level and throw
 * {@link ExceptionManage} when the named condition holds.
 *
 * <p>Naming: each method fails when its name is true of the argument, e.g.
 * {@code isNull(obj, msg)} throws when {@code obj} is {@code null}.
 *
 * <p>NOTE(review): {@code StringUtils} resolves through this file's imports
 * to JUnit's internal utility class; consider a production string utility.
 */
@Slf4j
public class ExceptionUtil {

    /** Logs {@code msg} and throws; shared tail of every guard. */
    private static void fail(String msg) {
        log.debug(msg);
        throw new ExceptionManage(msg);
    }

    /** Throws if {@code boole} is {@code true}; {@code null} counts as false. */
    public static void isTrue(Boolean boole, String msg) {
        if (Boolean.TRUE.equals(boole)) {
            fail(msg);
        }
    }

    /** Throws if {@code boole} is not {@code true}; {@code null} counts as not true. */
    public static void isNotTrue(Boolean boole, String msg) {
        if (!Boolean.TRUE.equals(boole)) {
            fail(msg);
        }
    }

    /** Throws if {@code obj} is {@code null}. */
    public static void isNull(Object obj, String msg) {
        if (obj == null) {
            fail(msg);
        }
    }

    /** Throws if {@code obj} is not {@code null}. */
    public static void isNotNull(Object obj, String msg) {
        if (obj != null) {
            fail(msg);
        }
    }

    /** Throws if {@code str} does not fully match {@code regex}; a {@code null} string never matches. */
    public static void isMatcher(String regex, String str, String msg) {
        if (str == null || !Pattern.matches(regex, str)) {
            fail(msg);
        }
    }

    /** Throws if {@code obj} is blank. */
    public static void isBlank(String obj, String msg) {
        if (StringUtils.isBlank(obj)) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is {@code null} or empty. */
    public static void isEmpty(Collection<?> value, String msg) {
        if (CollectionUtils.isEmpty(value)) {
            fail(msg);
        }
    }

    /** Throws if {@code str} is blank. */
    public static void isEmpty(String str, String msg) {
        if (StringUtils.isBlank(str)) {
            fail(msg);
        }
    }

    /** Throws if {@code value} contains at least one element. */
    public static void isNotEmpty(Collection<?> value, String msg) {
        if (!CollectionUtils.isEmpty(value)) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is not blank. */
    public static void isNotBlank(String value, String msg) {
        if (StringUtils.isNotBlank(value)) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is neither {@code null} nor zero. */
    public static void isNonZero(Integer value, String msg) {
        if (value != null && value != 0) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is neither {@code null} nor zero. */
    public static void isNonZero(Long value, String msg) {
        if (value != null && value != 0) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is neither {@code null} nor zero. */
    public static void isNonZero(Double value, String msg) {
        if (value != null && value != 0) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is {@code null} or zero. */
    public static void isNullOrZero(Integer value, String msg) {
        if (value == null || value == 0) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is {@code null} or zero. */
    public static void isNullOrZero(Long value, String msg) {
        if (value == null || value == 0) {
            fail(msg);
        }
    }

    /** Throws if {@code value} is {@code null} or zero. */
    public static void isNullOrZero(Double value, String msg) {
        if (value == null || value == 0) {
            fail(msg);
        }
    }

    /**
     * Unchecked exception carrying a numeric code (default 500) and a message.
     */
    static class ExceptionManage extends RuntimeException {
        protected int code = 500;
        protected String msg;

        public ExceptionManage(int code, String msg) {
            super(msg);
            this.msg = msg;
            this.code = code;
        }

        public ExceptionManage(String msg) {
            super(msg);
            this.msg = msg;
        }

        /** Uses a generic default message. */
        public ExceptionManage() {
            this("服务器出了点意外...");
        }

        /** Wraps {@code cause}; {@code msg} stays {@code null}. */
        public ExceptionManage(Exception cause) {
            super(cause);
        }

        public ExceptionManage(String msg, Exception cause) {
            super(msg, cause);
            this.msg = msg;
        }

        public int getCode() {
            return code;
        }

        public void setCode(int code) {
            this.code = code;
        }

        public String getMsg() {
            return msg;
        }

        public void setMsg(String msg) {
            this.msg = msg;
        }

        /** Renders code and message, or falls back to the default rendering plus the code. */
        @Override
        public String toString() {
            if (StringUtils.isNotBlank(msg)) {
                return "code:" + getCode() + ",msg:" + msg + ";";
            }
            return super.toString() + ";code:" + getCode();
        }
    }
}
模拟TCP恶意连接 测试结果
使用 JMeter 模拟 TCP 连接,向我们的 TCP 服务异步发起 120 个连接。测试结果:校验只根据 IP 进行,不带端口号,因为客户端每次发起 TCP 连接时使用的端口号都不同!
|