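I. Drawbacks of the first approach
To keep malicious clients from tying up network connection resources, the WebSocket connection needs an interceptor. After searching through a lot of online material, however, I could not find a way to intercept connections that are created through the @ServerEndpoint annotation. One blog post did suggest adding a custom configurator to @ServerEndpoint.
Appendix: source of that post
It intercepts the handshake by implementing the modifyHandshake method of the nested class ServerEndpointConfig.Configurator. I tried it, and it turned out to be less simple than the second approach. The second approach is also more general, so I consider it somewhat better than the first; choose according to your own situation.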
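II. The second approach: WebSocket clustering and communication
To build the cluster, add the Spring Cloud dependency, register each instance with the registry, and put a gateway in front for unified forwarding and load balancing. This is the same setup as in the previous post, so it is not shown here.
Result: (screenshot not included)
If the first approach is one solution per WebSocket connection, the second approach registers all WebSocket connections collectively and lets them share one solution. The code is as follows.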
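The shared-registration idea can be sketched without Spring or RabbitMQ. In the minimal sketch below every name (`ClusterRoutingSketch`, `Node`, `Broadcast`) is hypothetical, in-memory lists stand in for WebSocket sessions and for the message queue, and it is assumed that every node receives its own copy of each broadcast (with RabbitMQ that would mean one queue per node bound to the fanout exchange, because nodes that share a single queue compete for its messages). Each node first looks in its own registry and only falls back to the broadcast when the recipient is connected somewhere else.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ClusterRoutingSketch {

    /** Stand-in for the fanout exchange: hands every message to every node. */
    static class Broadcast {
        private final List<Node> nodes = new ArrayList<>();

        void register(Node node) {
            nodes.add(node);
        }

        void publish(String toUserId, String text) {
            for (Node node : nodes) {
                node.onBroadcast(toUserId, text);
            }
        }
    }

    /** One application instance that only knows its own connections. */
    static class Node {
        // userId -> messages delivered to that user (stand-in for a WebSocketSession)
        final Map<String, List<String>> sessions = new ConcurrentHashMap<>();
        private final Broadcast broadcast;

        Node(Broadcast broadcast) {
            this.broadcast = broadcast;
            broadcast.register(this);
        }

        void connect(String userId) {
            sessions.put(userId, new ArrayList<>());
        }

        /** Deliver locally when possible, otherwise let the other nodes try. */
        void send(String toUserId, String text) {
            if (!deliverLocally(toUserId, text)) {
                broadcast.publish(toUserId, text);
            }
        }

        /** Nodes that do not hold the recipient's session ignore the message. */
        void onBroadcast(String toUserId, String text) {
            deliverLocally(toUserId, text);
        }

        private boolean deliverLocally(String toUserId, String text) {
            List<String> inbox = sessions.get(toUserId);
            if (inbox == null) {
                return false;
            }
            inbox.add(text);
            return true;
        }
    }

    public static void main(String[] args) {
        Broadcast broadcast = new Broadcast();
        Node nodeA = new Node(broadcast);
        Node nodeB = new Node(broadcast);
        nodeA.connect("alice");
        nodeB.connect("bob");

        nodeA.send("bob", "hello bob");     // bob is on node B, so this goes through the broadcast
        nodeA.send("alice", "hello alice"); // alice is local to node A

        System.out.println(nodeB.sessions.get("bob"));   // [hello bob]
        System.out.println(nodeA.sessions.get("alice")); // [hello alice]
    }
}
```

The local lookup keeps same-node traffic off the queue; only messages for users on other nodes pay the cost of a broadcast.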
1. Directory structure
(directory tree screenshot not included)
2. pom dependencies
<parent>
<artifactId>spring-boot-starter-parent</artifactId>
<groupId>org.springframework.boot</groupId>
<version>2.1.6.RELEASE</version>
</parent>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-beans</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-aspects</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-starter</artifactId>
<version>2.1.0.RELEASE</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.9</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-amqp</artifactId>
</dependency>
</dependencies>
3. Code, listed from top to bottom of the directory structure
①SpringWebSocketConfig
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebMvc
@EnableWebSocket
// WebMvcConfigurerAdapter is deprecated since Spring 5, so WebMvcConfigurer is implemented directly
public class SpringWebSocketConfig implements WebMvcConfigurer, WebSocketConfigurer {
    // Both classes are already @Component beans; injecting them avoids creating second instances
    private final SpringWebSocketHandler webSocketHandler;
    private final SpringWebSocketHandlerInterceptor handshakeInterceptor;

    public SpringWebSocketConfig(SpringWebSocketHandler webSocketHandler,
                                 SpringWebSocketHandlerInterceptor handshakeInterceptor) {
        this.webSocketHandler = webSocketHandler;
        this.handshakeInterceptor = handshakeInterceptor;
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // setAllowedOrigins("*") allows cross-origin access
        registry.addHandler(webSocketHandler, "/webSocket")
                .addInterceptors(handshakeInterceptor).setAllowedOrigins("*");
        registry.addHandler(webSocketHandler, "/sockjs/socketServer.do")
                .addInterceptors(handshakeInterceptor).setAllowedOrigins("*");
    }
}
②The Controller layer has no real purpose here and can be left out
③SpringWebSocketHandler
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * @author secondj
 * @Date 2021/11/15 10:26
 */
@Component
public class SpringWebSocketHandler extends TextWebSocketHandler {
    // Key under which the user id is expected in the WebSocket session attributes
    private static final String SID_KEY = "tdt_sid";
    private static final AtomicInteger ati = new AtomicInteger();
    // Sessions connected to this node, keyed by user id
    public static final ConcurrentHashMap<String, WebSocketSession> map = new ConcurrentHashMap<>();
    private static final Logger logger = LoggerFactory.getLogger(SpringWebSocketHandler.class);

    @Autowired
    FanoutSender fanoutSender;

    /**
     * Runs once the connection is established (the page's onopen fires)
     * and sends a greeting to the front end
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        Object sid = session.getAttributes().get(SID_KEY);
        if (sid == null) {
            // A session without a user id cannot be addressed, so refuse it
            session.close(CloseStatus.POLICY_VIOLATION);
            return;
        }
        map.put(sid.toString(), session);
        int num = ati.incrementAndGet();
        logger.info("connect to the websocket success......current count: {}", num);
        // Your own business logic goes here, e.g. pushing offline messages once the user has logged in
        session.sendMessage(new TextMessage("Connected"));
    }

    /**
     * Fires when the connection is closed,
     * including when session.close() is called
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        logger.debug("websocket connection closed......");
        Object sid = session.getAttributes().get(SID_KEY);
        if (sid == null) {
            return; // the session was never registered
        }
        logger.info("User {} has left", sid);
        // The two-argument remove leaves a newer session of the same user in place
        map.remove(sid.toString(), session);
        int num = ati.decrementAndGet();
        logger.info("Users still online: {}", num);
    }

    /**
     * Called when the page calls websocket.send
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String payload = message.getPayload();
        WebSocketVO webSocketVO = JSONObject.parseObject(payload, WebSocketVO.class);
        if (webSocketVO == null || webSocketVO.getToUserId() == null) {
            logger.info("message without a recipient ignored: {}", payload);
            return;
        }
        String toSid = webSocketVO.getToUserId();
        WebSocketSession target = map.get(toSid);
        if (target != null) {
            target.sendMessage(message);
        } else {
            // The recipient is not connected to this node: publish {toSid: payload}
            // as real JSON so that the node holding the session can deliver it
            Map<String, String> sendMap = new ConcurrentHashMap<>();
            sendMap.put(toSid, payload);
            String routed = JSONObject.toJSONString(sendMap);
            logger.info("routed message: {}", routed);
            fanoutSender.sendMessage(routed);
        }
    }

    // Deliver a message to its recipient if that user is connected to this node
    public void sendMessage(TextMessage message) throws Exception {
        WebSocketVO webSocketVO = JSONObject.parseObject(message.getPayload(), WebSocketVO.class);
        if (webSocketVO == null || webSocketVO.getToUserId() == null) {
            return;
        }
        WebSocketSession target = map.get(webSocketVO.getToUserId());
        if (target != null) {
            target.sendMessage(message);
        }
    }

    /**
     * Runs when a transport error occurs
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        logger.debug("websocket transport error......", exception);
        if (session.isOpen()) {
            try {
                // Notify the client before closing; a closed session cannot send
                session.sendMessage(new TextMessage("An error occurred"));
            } catch (IOException e) {
                logger.debug("could not notify the client", e);
            }
            // Closing triggers afterConnectionClosed, which unregisters the session
            session.close(CloseStatus.SERVER_ERROR);
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Send a message to one user
     *
     * @param sid     id of the recipient
     * @param message message to deliver
     */
    public void sendMessageToUser(String sid, TextMessage message) {
        WebSocketSession session = map.get(sid);
        if (session == null || !session.isOpen()) {
            return;
        }
        try {
            session.sendMessage(message);
        } catch (IOException e) {
            logger.error("failed to send message to user {}", sid, e);
        }
    }

    public void sendMessageToUser(String sid, String message) {
        // The String constructor avoids a platform-dependent byte encoding
        this.sendMessageToUser(sid, new TextMessage(message));
    }

    /**
     * Send a message to every user connected to this node
     *
     * @param message message to deliver
     */
    public void sendMessageToUsers(TextMessage message) throws IOException {
        for (WebSocketSession session : map.values()) {
            if (session.isOpen()) {
                session.sendMessage(message);
            }
        }
    }
}
④SpringWebSocketHandlerInterceptor
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import javax.servlet.http.HttpSession;
import java.util.Map;
/**
 * @author secondj
 * @Date 2021/11/11 15:52
 */
@Component
public class SpringWebSocketHandlerInterceptor extends HttpSessionHandshakeInterceptor {
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
                                   Map<String, Object> attributes) throws Exception {
        // Read the parameters passed in the URL; whatever is put into attributes is handed
        // to the WebSocketHandler once the interceptor has finished, and the handler reads
        // it back through WebSocketSession.getAttributes().
        // Set the session values according to your own needs.
        if (request instanceof ServletServerHttpRequest) {
            ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
            HttpSession session = servletRequest.getServletRequest().getSession(true);
            String sid = servletRequest.getServletRequest().getParameter("userId");
            String token = servletRequest.getServletRequest().getParameter("token");
            // Constant-first equals rejects a missing token instead of throwing;
            // the value matches the token sent by the front-end page
            if (!"TDT_PERMISSION".equals(token)) {
                return false;
            }
            String userName = (String) session.getAttribute("exam_sid");
            if (userName == null) {
                userName = sid;
            }
            if (userName == null) {
                // Neither the HTTP session nor the URL identifies the user
                return false;
            }
            // Same key that SpringWebSocketHandler reads
            attributes.put("tdt_sid", userName);
            return true;
        }
        return false;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
                               Exception ex) {
        super.afterHandshake(request, response, wsHandler, ex);
    }
}
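The decision taken in beforeHandshake can also be isolated as a pure function, which makes the rejection rules easy to check without a servlet container. This is only a sketch: `HandshakeCheckSketch` and `resolveUser` are hypothetical names, the parameter map stands in for the request's query parameters, and the token value is the one the front-end page sends.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class HandshakeCheckSketch {
    // Token value taken from the front-end page's connection URL
    static final String EXPECTED_TOKEN = "TDT_PERMISSION";

    /**
     * Returns the user id to store in the WebSocket attributes,
     * or empty when the handshake should be refused.
     */
    static Optional<String> resolveUser(Map<String, String> queryParams, String sessionUser) {
        // Constant-first equals is null-safe: a missing token is simply rejected
        if (!EXPECTED_TOKEN.equals(queryParams.get("token"))) {
            return Optional.empty();
        }
        // Prefer the identity from the HTTP session, fall back to the URL parameter
        String user = sessionUser != null ? sessionUser : queryParams.get("userId");
        if (user == null || user.trim().isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(user);
    }

    public static void main(String[] args) {
        Map<String, String> params = new HashMap<>();
        params.put("userId", "u1");
        System.out.println(resolveUser(params, null).isPresent()); // false: no token
        params.put("token", EXPECTED_TOKEN);
        System.out.println(resolveUser(params, null).get());       // u1
        System.out.println(resolveUser(params, "sso-user").get()); // sso-user
    }
}
```

An interceptor built on such a function would return false for an empty result and otherwise put the user id into the attributes map.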
⑤RetryCache: this class caches queued messages and resends them; it is part of message reliability
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * @author secondj
 * @Date 2021/11/15 16:28
 */
@Slf4j
@Component
public class RetryCache {
    private SendMessage sendMessage;
    // volatile so that the retry thread sees a stop request made by another thread
    private volatile boolean stop = false;
    private Map<String, MessageWithTime> map = new ConcurrentHashMap<>();
    private AtomicInteger id = new AtomicInteger();

    @Data
    @AllArgsConstructor
    @NoArgsConstructor
    private static class MessageWithTime {
        long time;
        Object message;
    }

    public void sender(SendMessage sendMessage) {
        this.sendMessage = sendMessage;
        startRetry();
    }

    public String generaterId() {
        return "" + id.incrementAndGet();
    }

    public void add(String id, Object message) {
        map.put(id, new MessageWithTime(System.currentTimeMillis(), message));
    }

    public void del(String id) {
        map.remove(id);
    }

    // Resend cached messages on a background thread
    private void startRetry() {
        Thread retryThread = new Thread(() -> {
            while (!stop) {
                try {
                    // Wait one validity period between scans
                    Thread.sleep(Constant.VALID_TIME);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and end the retry loop
                    Thread.currentThread().interrupt();
                    return;
                }
                long now = System.currentTimeMillis();
                for (String key : map.keySet()) {
                    MessageWithTime messageWithTime = map.get(key);
                    if (null != messageWithTime) {
                        if (messageWithTime.getTime() + 3 * Constant.VALID_TIME < now) {
                            // Older than three validity periods: give up
                            log.info("send message failed after {} ms: {}", 3 * Constant.VALID_TIME, messageWithTime);
                            del(key);
                        } else if (messageWithTime.getTime() + Constant.VALID_TIME < now) {
                            // Older than one validity period: resend
                            DetailResult detailRes = sendMessage.send(messageWithTime.getMessage());
                            if (detailRes.isSuccess()) {
                                del(key);
                            }
                        }
                    }
                }
            }
        });
        // A daemon thread does not keep the JVM alive on shutdown
        retryThread.setDaemon(true);
        retryThread.start();
    }
}
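The retry loop above can also be driven by a scheduled executor rather than a hand-written sleep loop. The following is a minimal sketch under stated assumptions, not the class used in this project: `RetryCacheSketch` and all of its members are hypothetical, the resend action is passed in as a `Predicate` that returns true when the resend was accepted, and the scan takes the current time as a parameter so that it can be exercised deterministically.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

public class RetryCacheSketch {
    private static final class Pending {
        final long firstSentAt;
        final Object message;

        Pending(long firstSentAt, Object message) {
            this.firstSentAt = firstSentAt;
            this.message = message;
        }
    }

    private final Map<String, Pending> pending = new ConcurrentHashMap<>();
    private final Predicate<Object> resend;
    private final long validMillis;
    private ScheduledExecutorService scheduler;

    RetryCacheSketch(Predicate<Object> resend, long validMillis) {
        this.resend = resend;
        this.validMillis = validMillis;
    }

    void add(String id, Object message, long now) {
        pending.put(id, new Pending(now, message));
    }

    /** Call when the broker has confirmed the message. */
    void confirm(String id) {
        pending.remove(id);
    }

    int size() {
        return pending.size();
    }

    /** One pass over the cache: drop messages that are too old, resend the overdue ones. */
    void scanOnce(long now) {
        for (Map.Entry<String, Pending> entry : pending.entrySet()) {
            long age = now - entry.getValue().firstSentAt;
            if (age > 3 * validMillis) {
                pending.remove(entry.getKey()); // give up
            } else if (age > validMillis && resend.test(entry.getValue().message)) {
                pending.remove(entry.getKey()); // the resend was accepted
            }
        }
    }

    /** Scan periodically on a single background thread. */
    synchronized void start() {
        scheduler = Executors.newSingleThreadScheduledExecutor();
        scheduler.scheduleWithFixedDelay(
                () -> scanOnce(System.currentTimeMillis()),
                validMillis, validMillis, TimeUnit.MILLISECONDS);
    }

    synchronized void stop() {
        if (scheduler != null) {
            scheduler.shutdownNow();
        }
    }

    public static void main(String[] args) {
        List<Object> resent = new ArrayList<>();
        RetryCacheSketch cache = new RetryCacheSketch(resent::add, 1000);
        cache.add("1", "hello", 0);
        cache.add("2", "world", 0);
        cache.confirm("2");    // confirmed messages are never resent
        cache.scanOnce(500);   // too early, nothing happens
        cache.scanOnce(1500);  // overdue: "hello" is resent and removed
        System.out.println(resent + " pending=" + cache.size()); // [hello] pending=0
    }
}
```

Separating the scan from the scheduling keeps the timing rules testable, and `shutdownNow` gives a clean way to stop the background work.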
⑥RabbitMQConfig
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.*;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.connection.CorrelationData;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* @author secondj
* @Date 2021/11/15 16:20
*/
@Configuration
@Slf4j
public class RabbitMQConfig {
@Value("${spring.rabbitmq.addresses}")
public String addresses;
@Value("${spring.rabbitmq.port}")
public String port;
@Value("${spring.rabbitmq.username}")
private String username;
@Value("${spring.rabbitmq.password}")
private String password;
@Value("${spring.rabbitmq.virtual-host}")
private String virtualHost;
@Value("${spring.rabbitmq.publisher-confirms}")
private boolean publisherConfirms;
@Value("${tdt.queue}")
public String queue;
@Value("${tdt.exchange}")
public String exchange;
@Autowired
RetryCache retryCache;
/**
* Create the connection factory
*/
@Bean
public ConnectionFactory connectionFactory(){
CachingConnectionFactory connectionFactory = new CachingConnectionFactory();
connectionFactory.setHost("127.0.0.1");
connectionFactory.setPort(Integer.valueOf(port));
connectionFactory.setUsername(username);
connectionFactory.setPassword(password);
connectionFactory.setVirtualHost(virtualHost);
connectionFactory.setPublisherConfirms(publisherConfirms);
return connectionFactory;
}
@Bean
public Queue queueTdt(){
log.info("Queue created: {}",queue);
return new Queue(queue);
}
@Bean
public FanoutExchange fanoutExchangeTdt(){
log.info("Exchange created: {}",exchange);
return new FanoutExchange(exchange);
}
@Bean
public Binding bindingTdt(){
Binding bind = BindingBuilder.bind(queueTdt()).to(fanoutExchangeTdt());
log.info("Queue bound to exchange");
return bind;
}
@Bean
public RabbitTemplate rabbitTemplate(){
RabbitTemplate rabbitTemplate = new RabbitTemplate(connectionFactory());
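// Mandatory flag: unroutable messages are returned to the sender instead of being dropped
rabbitTemplate.setMandatory(true);
// Callback for returned (unroutable) messages
rabbitTemplate.setReturnCallback(returnCallback());
// Publisher confirm callback
rabbitTemplate.setConfirmCallback(confirmCallback());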
return rabbitTemplate;
}
//===============Publisher confirm===============
public RabbitTemplate.ConfirmCallback confirmCallback(){
return new RabbitTemplate.ConfirmCallback(){
@Override
public void confirm(CorrelationData correlationData,
boolean ack, String cause) {
if (ack) {
log.info("Publisher confirm: the broker accepted the message");
// The message is safe on the broker, so it no longer needs to stay in the retry cache
if (correlationData != null && correlationData.getId() != null) {
retryCache.del(correlationData.getId());
}
} else {
// Handle the failed message
log.info("Publisher confirm: the broker did not accept the message, consider resending: "+cause);
}
}
};
}
//===============Returned messages===============
public RabbitTemplate.ReturnCallback returnCallback(){
return new RabbitTemplate.ReturnCallback(){
@Override
public void returnedMessage(Message message,
int replyCode,
String replyText,
String exchange,
String routingKey) {
log.info("Message could not be routed and needs separate handling.");
log.info("Returned replyText:"+replyText);
log.info("Returned exchange:"+exchange);
log.info("Returned routingKey:"+routingKey);
String msgJson = new String(message.getBody());
log.info("Returned Message:"+msgJson);
}
};
}
}
⑦Constant
/**
* @author secondj
* @Date 2021/11/15 16:31
*/
public class Constant {
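// Validity period of a cached message in milliseconds (it is compared with System.currentTimeMillis())
public static final long VALID_TIME = 3600L;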
}
⑧ReceiveMessage and SendMessage
/**
* @author secondj
* @Date 2021/11/15 16:45
*/
public interface ReceiveMessage {
DetailResult receive(Object obj);
}
/**
* @author secondj
* @Date 2021/11/15 16:29
*/
public interface SendMessage {
DetailResult send(Object obj);
}
⑨FanoutReceiver
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.rabbitmq.client.Channel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
import org.springframework.amqp.support.AmqpHeaders;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.Header;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.Map;
/**
 * @author secondj
 * @Date 2021/11/15 16:22
 */
@Component
@Slf4j
public class FanoutReceiver {
    @Autowired
    SpringWebSocketHandler handler;

    // The manual ack below assumes the listener container uses manual acknowledge mode
    @RabbitListener(queues = "queue_mqsocket") // dynamic binding
    public void receiveMessage(String jsonObject, Channel channel, @Header(AmqpHeaders.DELIVERY_TAG) long tag) {
        try {
            log.info("Queue received message: {}", jsonObject);
            Map<?, ?> routed = parseRouted(jsonObject);
            for (Map.Entry<?, ?> entry : routed.entrySet()) {
                String sid = String.valueOf(entry.getKey());
                String message = String.valueOf(entry.getValue());
                // Only the node that holds the recipient's session delivers the message
                if (SpringWebSocketHandler.map.containsKey(sid)) {
                    handler.sendMessageToUser(sid, message);
                }
            }
            channel.basicAck(tag, false);
        } catch (JSONException e) {
            log.error("Discarding a message that is not valid JSON: {}", jsonObject, e);
            try {
                // Reject without requeueing so the broken message is not left unacknowledged
                channel.basicReject(tag, false);
            } catch (IOException ioe) {
                log.error("Reject failed", ioe);
            }
        } catch (IOException e) {
            log.error("Ack failed", e);
        }
    }

    private static Map<?, ?> parseRouted(String body) {
        try {
            return JSONObject.parseObject(body);
        } catch (JSONException e) {
            // Payloads built with Map.toString() use '=' between key and value
            return JSONObject.parseObject(body.replace("=", ":"));
        }
    }
}
⑩DetailResult
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* @author secondj
* @Date 2021/11/15 16:30
*/
@Data
@AllArgsConstructor
@NoArgsConstructor
public class DetailResult {
private boolean flag;
public Object message;
public boolean isSuccess() {
return flag;
}
}
⑪FanoutSender
import cn.tdt.rabbitmq.cache.RetryCache;
import cn.tdt.rabbitmq.function.SendMessage;
import cn.tdt.rabbitmq.result.DetailResult;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.rabbit.connection.CorrelationData;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
 * @author secondj
 * @Date 2021/11/15 16:22
 */
@Component
@Slf4j
public class FanoutSender implements SendMessage {
    // Use the Spring-managed cache rather than a private instance
    @Autowired
    RetryCache retryCache;
    @Autowired
    RabbitTemplate rabbitTemplate;

    // Send a message
    public void sendMessage(Object obj) {
        log.info("[Sender] publishing to the fanout exchange: " + JSONObject.toJSONString(obj));
        // send() reports failure through its result instead of throwing
        if (!send(obj).isSuccess()) {
            log.info("send failed, retrying once");
            if (!send(obj).isSuccess()) {
                log.info("retry send failed");
            }
        }
    }

    // Cache the message locally before handing it to the broker
    @Override
    public DetailResult send(Object message) {
        try {
            String id = retryCache.generaterId();
            retryCache.add(id, message);
            // The correlation id lets a publisher confirm callback identify this message
            rabbitTemplate.convertAndSend("exchange_mqsocket", "", message, new CorrelationData(id));
        } catch (Exception e) {
            log.error("publishing to the exchange failed", e);
            return new DetailResult(false, "");
        }
        return new DetailResult(true, "");
    }
}
⑫WebSocketVO
import lombok.Data;
/**
* @author secondj
* @Date 2021/11/15 14:53
*/
@Data
public class WebSocketVO {
private String toUserId;
private String msgType;
private String msgInfo;
}
⑬Startup class
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
* @author secondj
* @Date 2021/11/11 15:37
*/
@SpringBootApplication
public class SSOApplicationRun {
public static void main(String[] args) {
SpringApplication.run(SSOApplicationRun.class,args);
}
}
Front-end code
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>WebSocket</title>
</head>
<body>
<h3>hello socket</h3>
<p>【userId】:<div><input id="userId" name="userId" type="text"></div>
<p>【toUserId】:<div><input id="toUserId" name="toUserId" type="text"></div>
<p>【msgType】:<div><input id="msgType" name="msgType" type="text"></div>
<p>【msgInfo】:<div><input id="msgInfo" name="msgInfo" type="text"></div>
<p>【Action】:<div><button onclick="openSocket()">Open socket</button></div>
<p>【Action】:<div><button onclick="sendMessage()">Send message</button></div>
</body>
<script>
var socket;
function openSocket() {
    if (typeof(WebSocket) == "undefined") {
        console.log("Your browser does not support WebSocket");
    } else {
        console.log("Your browser supports WebSocket");
        // Instantiate the WebSocket object with the server address and port to open the connection
        var userId = document.getElementById('userId').value;
        // The parameters travel in the query string, as in a GET request, so the interceptor can read them with getParameter
        var socketUrl = "ws://localhost:8080/ctl/webSocket?token=TDT_PERMISSION&userId=" + encodeURIComponent(userId);
        console.log(socketUrl);
        if (socket != null) {
            socket.close();
            socket = null;
        }
        socket = new WebSocket(socketUrl);
        // Open event
        socket.onopen = function() {
            console.log("websocket opened");
            //socket.send("Message from the client " + location.href + new Date());
        };
        // Message event
        socket.onmessage = function(msg) {
            var serverMsg = "Received from server: " + msg.data;
            console.log(serverMsg);
            // A message has arrived; start the front-end handling here
        };
        // Close event
        socket.onclose = function() {
            console.log("websocket closed");
        };
        // Error event
        socket.onerror = function() {
            console.log("websocket error");
        };
    }
}
function sendMessage() {
    if (typeof(WebSocket) == "undefined") {
        console.log("Your browser does not support WebSocket");
    } else if (!socket || socket.readyState !== WebSocket.OPEN) {
        console.log("Open the socket before sending");
    } else {
        var toUserId = document.getElementById('toUserId').value;
        var msgInfo = document.getElementById('msgInfo').value;
        var msgType = document.getElementById('msgType').value;
        // JSON.stringify escapes quotes and other special characters in the input values
        var msg = JSON.stringify({toUserId: toUserId, msgInfo: msgInfo, msgType: msgType});
        console.log(msg);
        socket.send(msg);
    }
}
</script>
</html>
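Problems encountered:
When a message is sent to the message queue as an object, a Properties field is added automatically on the receiving side, which interferes with the JSON conversion. The solution is to convert the message object to a string or a JSON string before sending it.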
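To illustrate the difference between sending a string form that merely looks structured and sending real JSON, here is a minimal sketch that uses only the JDK so that it stays self-contained. `JsonPayloadSketch`, `quote` and `toJson` are hypothetical helpers that assume non-null keys and values; in the project itself fastjson's `JSONObject.toJSONString` would do this job.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class JsonPayloadSketch {

    /** Wraps a string in double quotes and escapes the characters JSON requires. */
    static String quote(String s) {
        StringBuilder sb = new StringBuilder("\"");
        for (char c : s.toCharArray()) {
            switch (c) {
                case '"':  sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                case '\n': sb.append("\\n");  break;
                case '\r': sb.append("\\r");  break;
                case '\t': sb.append("\\t");  break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.append('"').toString();
    }

    /** Serializes a flat string map as a JSON object. */
    static String toJson(Map<String, String> fields) {
        StringBuilder sb = new StringBuilder("{");
        boolean first = true;
        for (Map.Entry<String, String> e : fields.entrySet()) {
            if (!first) {
                sb.append(',');
            }
            sb.append(quote(e.getKey())).append(':').append(quote(e.getValue()));
            first = false;
        }
        return sb.append('}').toString();
    }

    public static void main(String[] args) {
        Map<String, String> routed = new LinkedHashMap<>();
        routed.put("u2", "{\"toUserId\":\"u2\",\"msgInfo\":\"a=b\"}");
        // Map.toString() uses '=' and no quoting, so the receiver cannot parse it as JSON
        System.out.println(routed);
        // A JSON string survives the trip through the queue unchanged
        System.out.println(toJson(routed));
    }
}
```

Because the payload is already a plain string when it reaches the queue, the receiver gets back exactly what was sent and can parse it directly.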
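Open issues to optimize:
In a single sign-on system, users who are not logged in should be redirected after the WebSocket request has been intercepted.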