import * as StompJS from "@stomp/stompjs";
import difference from "lodash/difference";
import find from "lodash/find";
import intersection from "lodash/intersection";
import { useEffect, useRef, useState } from "react";

/**
 * A topic the hook should be subscribed to, together with the handler for
 * messages received on it.
 */
export interface ITopic {
    /** Destination name; passed unchanged to the STOMP client's `subscribe` call. */
    topic: string;
    /** Called with each message delivered on the subscription for `topic`. */
    onMessage: (m: StompJS.Message) => void;
}

/**
 * Internal record of one active subscription, kept in the hook's state.
 */
interface ISubscription {
    /** Topic name the subscription was created for. */
    topic: string;
    /** Handle returned by the client's `subscribe`; used to unsubscribe once the topic is dropped. */
    subscription: StompJS.StompSubscription;
}

/**
 * React hook that maintains a STOMP-over-WebSocket connection to `url` and
 * keeps the client's subscriptions in sync with `topics`.
 *
 * A fresh client is created and activated whenever `url` changes; the previous
 * one is deactivated. While connected, topics added to `topics` are subscribed
 * and topics removed from it are unsubscribed. Incoming messages are routed to
 * the `onMessage` of the matching entry in the *latest* `topics` array, so
 * handlers may change between renders without resubscribing.
 *
 * @param url               Broker URL handed to the client as `brokerURL`.
 * @param topics            Topics to be subscribed to, with their handlers.
 * @param connectHeaders    Optional headers set on the client for connecting.
 * @param subscribeHeaders  Optional headers passed to every `subscribe` call.
 * @param reconnectDelay    Client `reconnectDelay` in ms (default 1000).
 * @param heartbeatIncoming Client `heartbeatIncoming` in ms (default 10000).
 * @param heartbeatOutgoing Client `heartbeatOutgoing` in ms (default 10000).
 * @param onConnect         Called each time the client reports a connection.
 * @param onDisconnect      Called when the client reports a disconnect, and
 *                          also on a reconnect that happens while the hook
 *                          still considered itself connected.
 * @param onWebSocketError  Called when the client reports a WebSocket error.
 */
const useStompWS = ({
    url,
    topics,
    connectHeaders,
    subscribeHeaders,
    reconnectDelay = 1000,
    heartbeatIncoming = 10000,
    heartbeatOutgoing = 10000,
    onConnect,
    onDisconnect,
    onWebSocketError
}: {
    url: string;
    topics: ITopic[];
    connectHeaders?: StompJS.StompHeaders;
    subscribeHeaders?: StompJS.StompHeaders;
    reconnectDelay?: number;
    heartbeatIncoming?: number;
    heartbeatOutgoing?: number;
    onConnect?: () => void;
    onDisconnect?: () => void;
    onWebSocketError?: () => void;
}): void => {
    const [connected, setConnected] = useState(false);
    // Lazy initializer: the placeholder client is constructed once instead of
    // on every render (it is replaced by the real client in the url effect).
    const [client, setClient] = useState(() => new StompJS.Client());
    const [subscriptions, setSubscriptions] = useState<ISubscription[]>([]);
    const topicsRef = useRef<ITopic[]>([]);

    // Latest connection settings, readable from the url effect without making
    // the effect (and therefore the connection) depend on them.
    const settings = {
        connectHeaders,
        heartbeatIncoming,
        heartbeatOutgoing,
        reconnectDelay
    };
    const settingsRef = useRef(settings);

    topicsRef.current = topics;
    settingsRef.current = settings;

    // (Re)create the client whenever the broker URL changes.
    useEffect(() => {
        setSubscriptions([]);
        setConnected(false);

        const c = new StompJS.Client({brokerURL: url});

        // Apply the settings BEFORE activating. Otherwise they would only be
        // assigned during the next render, which may be after the client has
        // already started its first connection attempt with default values.
        const initial = settingsRef.current;
        if (initial.connectHeaders) {
            c.connectHeaders = initial.connectHeaders;
        }
        c.heartbeatIncoming = initial.heartbeatIncoming;
        c.heartbeatOutgoing = initial.heartbeatOutgoing;
        c.reconnectDelay = initial.reconnectDelay;

        c.onStompError = (frame: StompJS.Frame) => {
            console.error(`Broker reported error: ${frame.headers.message}`);
            console.error(`Additional details: ${frame.body}`);
        };
        c.activate();
        setClient(c);

        return () => {
            c.deactivate();
        };
    }, [url]);

    // Reconcile active subscriptions with the requested topics.
    useEffect((): void => {
        // `connected` is React state and is not reset on a plain socket close,
        // so also ask the client itself: subscribing without a live STOMP
        // connection throws. After a reconnect `onConnect` resets
        // `subscriptions`, which reruns this effect.
        if (!connected || !client.connected) {
            return;
        }

        const currentTopics = topics.map((t) => t.topic);
        const subscribedTopics = subscriptions.map((s) => s.topic);
        const newSubscriptions: ISubscription[] = [];
        let hasChange = false;

        // keep recurring subscriptions
        intersection(subscribedTopics, currentTopics).forEach((t) => {
            const keptSub = find(subscriptions, ({topic}) => topic === t);
            if (keptSub) {
                newSubscriptions.push(keptSub);
            }
        });

        // unsubscribe from deleted topics
        difference(subscribedTopics, currentTopics).forEach((t) => {
            hasChange = true;
            const removedSub = find(subscriptions, ({topic}) => topic === t);
            if (removedSub) {
                removedSub.subscription.unsubscribe();
            }
        });

        // subscribe to new topics
        difference(currentTopics, subscribedTopics).forEach((t) => {
            hasChange = true;
            const newSub = find(topics, ({topic}) => topic === t);
            if (newSub && client) {
                newSubscriptions.push({
                    topic: t,
                    subscription: client.subscribe(
                        newSub.topic,
                        (m: StompJS.Message): void => {
                            // Resolve the handler at delivery time so the most
                            // recently supplied `onMessage` is the one called.
                            const messagedTopic = find(
                                topicsRef.current,
                                ({topic}) => topic === newSub.topic
                            );
                            if (messagedTopic) {
                                messagedTopic.onMessage(m);
                            }
                        },
                        subscribeHeaders
                    )
                });
            }
        });

        // Only touch state when something changed, otherwise this effect
        // (which depends on `subscriptions`) would retrigger itself forever.
        if (hasChange) {
            setSubscriptions(newSubscriptions);
        }
    }, [subscribeHeaders, topics, client, connected, subscriptions]);

    // Keep the current client's settings and callbacks up to date on every
    // render so later reconnects use the latest props and closures.
    if (connectHeaders) {
        client.connectHeaders = connectHeaders;
    }
    client.heartbeatIncoming = heartbeatIncoming;
    client.heartbeatOutgoing = heartbeatOutgoing;
    client.reconnectDelay = reconnectDelay;
    client.onConnect = () => {
        if (connected) {
            // Connected again while the hook still thought it was connected:
            // report the missed disconnect and drop the stale subscriptions so
            // they are re-created on the new connection.
            if (onDisconnect) {
                onDisconnect();
            }
            setSubscriptions([]);
        }
        if (onConnect) {
            onConnect();
        }
        setConnected(true);
    };
    client.onDisconnect = () => {
        if (onDisconnect) {
            onDisconnect();
        }
        setConnected(false);
        setSubscriptions([]);
    };
    client.onWebSocketError = () => {
        if (onWebSocketError) {
            onWebSocketError();
        }
        setConnected(false);
        setSubscriptions([]);
    };
};

export default useStompWS;
