xuhy
1 天以前 a960c432d78dfe5f0ef07295d0210ddb09340e12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
package com.ruoyi.web.controller.webSocket;
 
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
 
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
 
/**
 * 基于AppUser的WebSocket连接管理器
 * 支持按用户ID管理连接,一个用户可以有多个设备连接
 */
@Component
@Slf4j
public class WebSocketUserConnectionManager {
    
    /**
     * 存储用户的所有连接
     * Key: 用户ID (appUserId)
     * Value: 该用户的所有连接Channel列表
     */
    private static final ConcurrentHashMap<String, Set<Channel>> userConnections = new ConcurrentHashMap<>();
    
    /**
     * 存储连接与用户的映射关系
     * Key: ChannelId.asLongText()
     * Value: 用户ID (appUserId)
     */
    private static final ConcurrentHashMap<String, String> channelToUser = new ConcurrentHashMap<>();
    
    /**
     * 存储所有连接
     * Key: ChannelId.asLongText()
     * Value: Channel
     */
    private static final ConcurrentHashMap<String, Channel> allConnections = new ConcurrentHashMap<>();
    
    /**
     * 添加用户连接
     * @param appUserId 用户ID
     * @param channel WebSocket连接通道
     */
    public void addUserConnection(String appUserId, Channel channel) {
        String channelId = channel.id().asLongText();
        
        // 添加到用户连接映射
        userConnections.computeIfAbsent(appUserId, k -> ConcurrentHashMap.newKeySet())
                .add(channel);
        
        // 添加到连接用户映射
        channelToUser.put(channelId, appUserId);
        
        // 添加到所有连接
        allConnections.put(channelId, channel);
        
        log.info("用户连接已添加,用户ID: {}, ChannelId: {}, 该用户连接数: {}", 
                appUserId, channelId, userConnections.get(appUserId).size());
    }
    
    /**
     * 移除用户连接
     * @param channel WebSocket连接通道
     */
    public void removeUserConnection(Channel channel) {
        String channelId = channel.id().asLongText();
        String appUserId = channelToUser.get(channelId);
        
        if (appUserId != null) {
            // 从用户连接中移除
            Set<Channel> userChannels = userConnections.get(appUserId);
            if (userChannels != null) {
                userChannels.remove(channel);
                // 如果用户没有连接了,删除用户映射
                if (userChannels.isEmpty()) {
                    userConnections.remove(appUserId);
                }
            }
            
            // 从连接用户映射中移除
            channelToUser.remove(channelId);
            
            log.info("用户连接已移除,用户ID: {}, ChannelId: {}, 该用户剩余连接数: {}", 
                    appUserId, channelId, userChannels != null ? userChannels.size() : 0);
        }
        
        // 从所有连接中移除
        allConnections.remove(channelId);
    }
    
    /**
     * 向指定用户的所有设备发送消息
     * @param appUserId 用户ID
     * @param message 消息内容
     * @return 成功发送的设备数
     */
    public int sendMessageToUser(String appUserId, String message) {
        Set<Channel> userChannels = userConnections.get(appUserId);
        if (userChannels == null || userChannels.isEmpty()) {
            log.warn("用户 {} 没有活跃连接", appUserId);
            return 0;
        }
        
        int successCount = 0;
        for (Channel channel : userChannels) {
            if (channel != null && channel.isActive()) {
                try {
                    channel.writeAndFlush(new TextWebSocketFrame(message));
                    successCount++;
                } catch (Exception e) {
                    log.error("向用户 {} 发送消息失败,连接ID: {}, 错误: {}", 
                            appUserId, channel.id().asLongText(), e.getMessage());
                }
            }
        }
        
        log.info("向用户 {} 发送消息完成,成功发送到 {} 个设备,消息内容: {}", 
                appUserId, successCount, message);
        return successCount;
    }
    
    /**
     * 向指定连接发送消息
     * @param channelId 连接ID
     * @param message 消息内容
     * @return 是否发送成功
     */
    public boolean sendMessageToChannel(String channelId, String message) {
        Channel channel = allConnections.get(channelId);
        if (channel != null && channel.isActive()) {
            try {
                channel.writeAndFlush(new TextWebSocketFrame(message));
                log.info("消息已发送到连接: {}, 内容: {}", channelId, message);
                return true;
            } catch (Exception e) {
                log.error("发送消息失败,连接ID: {}, 错误: {}", channelId, e.getMessage());
                return false;
            }
        } else {
            log.warn("连接不存在或已关闭,ChannelId: {}", channelId);
            return false;
        }
    }
    
    /**
     * 向所有连接广播消息
     * @param message 消息内容
     * @return 成功发送的连接数
     */
    public int broadcastMessage(String message) {
        int successCount = 0;
        for (Channel channel : allConnections.values()) {
            if (channel != null && channel.isActive()) {
                try {
                    channel.writeAndFlush(new TextWebSocketFrame(message));
                    successCount++;
                } catch (Exception e) {
                    log.error("广播消息失败,连接ID: {}, 错误: {}", 
                            channel.id().asLongText(), e.getMessage());
                }
            }
        }
        log.info("广播消息完成,成功发送到 {} 个连接,消息内容: {}", successCount, message);
        return successCount;
    }
    
    /**
     * 向多个用户发送消息
     * @param appUserIds 用户ID列表
     * @param message 消息内容
     * @return 成功发送的用户数
     */
    public int sendMessageToUsers(List<String> appUserIds, String message) {
        int successUserCount = 0;
        for (String appUserId : appUserIds) {
            int deviceCount = sendMessageToUser(appUserId, message);
            if (deviceCount > 0) {
                successUserCount++;
            }
        }
        log.info("向 {} 个用户发送消息完成,成功发送到 {} 个用户", appUserIds.size(), successUserCount);
        return successUserCount;
    }
    
    /**
     * 获取用户的所有连接
     * @param appUserId 用户ID
     * @return 连接列表
     */
    public Set<Channel> getUserConnections(String appUserId) {
        return userConnections.getOrDefault(appUserId, Collections.emptySet());
    }
    
    /**
     * 获取用户连接数
     * @param appUserId 用户ID
     * @return 连接数
     */
    public int getUserConnectionCount(String appUserId) {
        Set<Channel> userChannels = userConnections.get(appUserId);
        return userChannels != null ? userChannels.size() : 0;
    }
    
    /**
     * 获取总连接数
     * @return 总连接数
     */
    public int getTotalConnectionCount() {
        return allConnections.size();
    }
    
    /**
     * 获取在线用户数
     * @return 在线用户数
     */
    public int getOnlineUserCount() {
        return userConnections.size();
    }
    
    /**
     * 获取所有在线用户ID
     * @return 用户ID集合
     */
    public Set<String> getOnlineUserIds() {
        return userConnections.keySet();
    }
    
    /**
     * 检查用户是否在线
     * @param appUserId 用户ID
     * @return 是否在线
     */
    public boolean isUserOnline(String appUserId) {
        Set<Channel> userChannels = userConnections.get(appUserId);
        if (userChannels == null || userChannels.isEmpty()) {
            return false;
        }
        // 检查是否有活跃连接
        return userChannels.stream().anyMatch(Channel::isActive);
    }
    
    /**
     * 检查连接是否存在
     * @param channelId 连接ID
     * @return 是否存在
     */
    public boolean isConnectionExists(String channelId) {
        Channel channel = allConnections.get(channelId);
        return channel != null && channel.isActive();
    }
    
    /**
     * 根据连接ID获取用户ID
     * @param channelId 连接ID
     * @return 用户ID
     */
    public String getUserIdByChannelId(String channelId) {
        return channelToUser.get(channelId);
    }
    
    /**
     * 获取连接统计信息
     * @return 统计信息
     */
    public Map<String, Object> getConnectionStats() {
        Map<String, Object> stats = new HashMap<>();
        stats.put("totalConnections", getTotalConnectionCount());
        stats.put("onlineUsers", getOnlineUserCount());
        stats.put("onlineUserIds", getOnlineUserIds());
        
        // 统计每个用户的连接数
        Map<String, Integer> userConnectionCounts = new HashMap<>();
        for (Map.Entry<String, Set<Channel>> entry : userConnections.entrySet()) {
            userConnectionCounts.put(entry.getKey(), entry.getValue().size());
        }
        stats.put("userConnectionCounts", userConnectionCounts);
        
        return stats;
    }
}