diff --git a/prudp_connection.go b/prudp_connection.go index 11630aa..c205595 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -261,9 +261,7 @@ func (pc *PRUDPConnection) startHeartbeat() { // * If the heartbeat still did not restart, assume the // * connection is dead and clean up pc.pingKickTimer = time.AfterFunc(maxSilenceTime, func() { - pc.cleanup() // * "removed" event is dispatched here - - endpoint.deleteConnectionByID(pc.ID) + endpoint.cleanupConnection(pc) }) }) } diff --git a/prudp_endpoint.go b/prudp_endpoint.go index f4f17a1..c1a2a0f 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -105,11 +105,24 @@ func (pep *PRUDPEndPoint) EmitError(err *Error) { } } -// deleteConnectionByID deletes the connection with the specified ID -func (pep *PRUDPEndPoint) deleteConnectionByID(cid uint32) { - pep.Connections.DeleteIf(func(key string, value *PRUDPConnection) bool { - return value.ID == cid +// cleanupConnection cleans up and deletes a connection from this endpoint. Will lock the Connections mutex - make sure +// you don't hold it during a call, or this will deadlock +func (pep *PRUDPEndPoint) cleanupConnection(connection *PRUDPConnection) { + discriminator := fmt.Sprintf("%s-%d-%d", connection.Socket.Address.String(), connection.StreamType, connection.StreamID) + + found := false + pep.Connections.RunAndDelete(discriminator, func(key string, conn *PRUDPConnection) { + found = true }) + + // * Probably this connection is on a different PRUDPEndPoint + if !found { + logger.Warningf("Tried to delete connection %v (ID %v) but it doesn't exist!", discriminator, connection.ID) + } + + // * We can't do this during RunAndDelete, since we hold the Connections mutex then + // * This way we avoid any recursive locking + connection.cleanup() } func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *SocketConnection) { @@ -417,10 +430,7 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { streamID := packet.SourceVirtualPortStreamID() discriminator := fmt.Sprintf("%s-%d-%d", packet.Sender().Address().String(), streamType, streamID) if connection, ok := pep.Connections.Get(discriminator); ok { - // * We make sure to update the connection state here because we could still be attempting to - // * resend packets. - connection.cleanup() - pep.Connections.Delete(discriminator) + pep.cleanupConnection(connection) } pep.emit("disconnect", packet) @@ -698,7 +708,7 @@ func (pep *PRUDPEndPoint) FindConnectionByPID(pid uint64) *PRUDPConnection { var connection *PRUDPConnection pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { - if pc.pid.Value() == pid { + if pc.pid.Value() == pid && pc.ConnectionState == StateConnected { connection = pc return true } diff --git a/timeout_manager.go b/timeout_manager.go index 2cacb9e..22fb943 100644 --- a/timeout_manager.go +++ b/timeout_manager.go @@ -53,10 +53,10 @@ func (tm *TimeoutManager) start(packet PRUDPPacketInterface) { } if tm.packets.Has(packet.SequenceID()) { + endpoint := packet.Sender().Endpoint().(*PRUDPEndPoint) + // * This is `<` instead of `<=` for accuracy with observed behavior, even though we're comparing send count vs _resend_ max if packet.SendCount() < tm.streamSettings.MaxPacketRetransmissions { - endpoint := packet.Sender().Endpoint().(*PRUDPEndPoint) - packet.incrementSendCount() packet.setSentAt(time.Now()) rto := endpoint.ComputeRetransmitTimeout(packet) @@ -76,9 +76,7 @@ func (tm *TimeoutManager) start(packet PRUDPPacketInterface) { server.sendRaw(connection.Socket, data) } else { // * Packet has been retried too many times, consider the connection dead - connection.Lock() - defer connection.Unlock() - connection.cleanup() + endpoint.cleanupConnection(connection) } } } diff --git a/websocket_server.go b/websocket_server.go index a25c644..ece53e5 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -22,25 +22,26 @@ func (wseh *wsEventHandler) OnOpen(socket *gws.Conn) { } func (wseh *wsEventHandler) OnClose(wsConn *gws.Conn, _ error) { - connections := make([]*PRUDPConnection, 0) - // * Loop over all connections on all endpoints wseh.prudpServer.Endpoints.Each(func(streamid uint8, pep *PRUDPEndPoint) bool { - return pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { + connections := make([]*PRUDPConnection, 0) + + pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { if pc.Socket.Address == wsConn.RemoteAddr() { connections = append(connections, pc) } return false }) - }) - // * We cannot modify a MutexMap while looping over it - // * since the mutex is locked. We first need to grab - // * the entries we want to delete, and then loop over - // * them here to actually clean them up - for _, connection := range connections { - connection.cleanup() // * "removed" event is dispatched here - } + // * We cannot modify a MutexMap while looping over it + // * since the mutex is locked. We first need to grab + // * the entries we want to delete, and then loop over + // * them here to actually clean them up + for _, connection := range connections { + pep.cleanupConnection(connection) // * "removed" event is dispatched here + } + return false + }) } func (wseh *wsEventHandler) OnPing(socket *gws.Conn, payload []byte) {