package org.springblade.modules.sse.server; import lombok.extern.slf4j.Slf4j; import org.springframework.http.MediaType; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @Slf4j public class SSEServer { /** * 当前连接数 */ private static AtomicInteger count = new AtomicInteger(0); private static Map sseEmitterMap = new ConcurrentHashMap<>(); public static SseEmitter connect(String userId){ //设置超时时间,0表示不过期,默认是30秒,超过时间未完成会抛出异常 SseEmitter sseEmitter = new SseEmitter(0L); //注册回调 sseEmitter.onCompletion(completionCallBack(userId)); sseEmitter.onError(errorCallBack(userId)); sseEmitter.onTimeout(timeOutCallBack(userId)); sseEmitterMap.put(userId,sseEmitter); //数量+1 count.getAndIncrement(); log.info("create new sse connect ,current user:{}",userId); log.info("count",count.getAndIncrement()); return sseEmitter; } /** * 给指定用户发消息 */ public static void sendMessage(String userId, String message){ if(sseEmitterMap.containsKey(userId)){ try{ sseEmitterMap.get(userId).send(message); }catch (IOException e){ log.error("user id:{}, send message error:{}",userId,e.getMessage()); e.printStackTrace(); } } } /** * 想多人发送消息,组播 */ public static void groupSendMessage(String groupId, String message){ if(sseEmitterMap!=null&&!sseEmitterMap.isEmpty()){ sseEmitterMap.forEach((k,v) -> { try{ if(k.startsWith(groupId)){ v.send(message, MediaType.APPLICATION_JSON); } }catch (IOException e){ log.error("user id:{}, send message error:{}",groupId,message); removeUser(k); } }); } } /** * 批量发送消息 * @param message */ public static void batchSendMessage(String message) { sseEmitterMap.forEach((k,v)->{ try{ v.send(message,MediaType.APPLICATION_JSON); }catch (IOException e){ log.error("user id:{}, send message error:{}",k,e.getMessage()); removeUser(k); } }); } /** * 群发消息 */ public static void batchSendMessage(String message, Set userIds){ userIds.forEach(userId->sendMessage(userId,message)); } /** * 用户离线删除用户 * @param userId */ public static void removeUser(String userId){ sseEmitterMap.remove(userId); //数量-1 count.getAndDecrement(); log.info("remove user id:{}",userId); } public static List getIds(){ return new ArrayList<>(sseEmitterMap.keySet()); } public static int getUserCount(){ return count.intValue(); } /** * 结束回调 * @param userId * @return */ private static Runnable completionCallBack(String userId) { return () -> { log.info("结束连接,{}",userId); removeUser(userId); }; } /** * 超时回调 * @param userId * @return */ private static Runnable timeOutCallBack(String userId){ return ()->{ log.info("连接超时,{}",userId); removeUser(userId); }; } /** * 错误回调 * @param userId * @return */ private static Consumer errorCallBack(String userId){ return throwable -> { log.error("连接异常,{}",userId); removeUser(userId); }; } }