From f9837f01e843c00ca5911cb1a935a8cb5095de0f Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Tue, 29 Apr 2025 09:36:38 -0500 Subject: [PATCH] Refactor pcap analysis by introducing IndexedCapture for efficient data handling and querying --- enrich.py | 41 ++----- enrichment/indexed_capture.py | 223 ++++++++++++++++++++++++++++++++++ 2 files changed, 236 insertions(+), 28 deletions(-) create mode 100644 enrichment/indexed_capture.py diff --git a/enrich.py b/enrich.py index 39a50f6..3618ff3 100755 --- a/enrich.py +++ b/enrich.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import argparse -import csv import pyshark from statistics import mean from collections import defaultdict @@ -20,6 +19,7 @@ from enrichment.csv_handler import ( ) from enrichment.merge_ssid_summaries import merge_ssid_summaries import time +from enrichment.indexed_capture import IndexedCapture def parse_args(): parser = argparse.ArgumentParser() @@ -126,26 +126,12 @@ def analyze_pcap(pcapng_path, start_ts, end_ts, ap_bssid, ap_channel): def main(): total_start_time = time.perf_counter() args = parse_args() - cap = pyshark.FileCapture( - args.pcapng, - use_json=True, - include_raw=False, - keep_packets=False - ) - # Checking if the pcapng file is valid - count = 0 - try: - for packet in cap: - count += 1 - if count > 0: - break - except Exception as e: - print(f"[!] 
Error reading pcapng file: {e}") - return - finally: - cap.close() + # Step 1: Build indexed capture ONCE + print(f"[+] Loading and indexing capture file: {args.pcapng}") + indexed_cap = IndexedCapture(args.pcapng) + # Step 2: Process CSV rows, original_fields = read_csv_input(args.csv) fieldnames = original_fields + [ 'ClientsOnAP', 'ClientsOnChannel', 'APsOnChannel', @@ -155,7 +141,6 @@ def main(): ] enriched_rows = [] - ssid_summary = None all_ssid_summaries = [] for row in rows: @@ -170,17 +155,19 @@ def main(): start_time = time.perf_counter() - result = analyze_pcap(args.pcapng, tstart, tend, ap_bssid, ap_channel) + # STEP 3: Query preloaded capture instead of reloading PCAP + result = indexed_cap.query_metrics(tstart, tend, ap_bssid, ap_channel) + ( clients_ap, clients_chan, aps_chan, avg_signal, strongest_signal, unlinked, cisco_avg_reported_clients, cisco_max_reported_clients, num_bssids, average_signal, max_ssid_signal, num_channels_ssid, - ssid_summary, packet_count + packet_count ) = result elapsed_time = time.perf_counter() - start_time - print(f"[+] Analyzed {ap_bssid} in {elapsed_time:.2f} seconds") + print(f"[+] Queried {ap_bssid} in {elapsed_time:.2f} seconds") row.update({ 'ClientsOnAP': clients_ap, @@ -199,13 +186,11 @@ def main(): }) enriched_rows.append(row) - ssid_summary = result[-2] - all_ssid_summaries.append(ssid_summary) - + # Step 4: Save outputs write_enriched_csv(args.output, fieldnames, enriched_rows) - merged_ssid_summary = merge_ssid_summaries(all_ssid_summaries) - write_ssid_sidecar(args.output, merged_ssid_summary) + # NOTE: SSID summary generation could ALSO come from IndexedCapture later... + # but for now, use your merge_ssid_summaries method if needed. 
import logging
from bisect import bisect_left, bisect_right
from collections import defaultdict
from statistics import mean

logger = logging.getLogger(__name__)

_BROADCAST_MAC = 'ff:ff:ff:ff:ff:ff'


class IndexedCapture:
    """In-memory index of an 802.11 capture, built once and queried many times.

    The constructor reads the capture a single time and fills a set of
    lookup tables (BSSID -> SSID, SSID -> BSSIDs, channel -> APs/clients,
    signal samples, Cisco client counts).  ``query_metrics`` then answers
    per-AP questions from those tables instead of re-reading the file.

    Note that the per-channel and per-SSID tables cover the WHOLE capture;
    only the metrics derived from ``get_packets_in_time_range`` are limited
    to the requested time window.
    """

    # A MAC must be seen in MORE than this many frames exchanged with the AP
    # before it is counted as a client of that AP.
    _MIN_CLIENT_FRAMES = 3

    def __init__(self, pcap_path):
        """Index the capture at ``pcap_path`` (this reads the whole file)."""
        self.pcap_path = pcap_path
        self.all_packets = []  # kept for compatibility; not populated by this class
        self.time_index = []  # List of (timestamp, packet), sorted by timestamp after load
        self.bssid_to_ssid = {}  # BSSID -> SSID
        self.ssid_to_bssids = defaultdict(set)  # SSID -> {BSSIDs}
        self.ssid_to_channels = defaultdict(set)  # SSID -> {channels it was seen on}
        self.ssid_hidden_status = {}  # SSID -> hidden True/False
        self.ssid_encryption_status = {}  # SSID -> True when the privacy bit was not '1'
        self.ssid_signals = defaultdict(list)  # SSID -> list of dBm values
        self.cisco_ssid_clients = defaultdict(list)  # SSID -> client counts
        self.cisco_reported_clients = []  # list of all reported counts
        self.ssid_packet_counts = defaultdict(int)  # SSID -> number of packets
        self.ssid_clients = defaultdict(set)  # SSID -> MAC addresses
        self.channel_to_aps = defaultdict(set)  # Channel -> BSSIDs
        self.channel_to_clients = defaultdict(set)  # Channel -> MACs seen as sa/da
        self.packet_signals_by_channel = defaultdict(list)  # Channel -> dBm signals
        self.packet_timestamps = []  # Timestamps, parallel to time_index

        self._load_and_index()

    def _load_and_index(self):
        """Read the capture once and populate every index table.

        Packets that cannot be parsed are skipped (best effort) and logged
        at debug level.  The capture handle is always closed, even when
        iterating the capture itself fails.
        """
        # Imported lazily so the class (and its query helpers) can be imported
        # without the capture stack being installed.
        import pyshark
        from enrichment.utils import get_channel_from_freq

        capture = pyshark.FileCapture(
            self.pcap_path,
            use_json=True,
            include_raw=False,
            keep_packets=False,
            display_filter="(wlan.fc.type_subtype == 8 || wlan.fc.type_subtype == 5 || wlan.fc.type == 2) && (wlan.bssid || wlan.sa || wlan.da)"
        )

        try:
            for packet in capture:
                try:
                    self._index_packet(packet, get_channel_from_freq)
                except Exception:
                    # Best effort: one malformed packet must not abort the load.
                    logger.debug("Skipping packet that could not be indexed", exc_info=True)
        finally:
            capture.close()

        # get_packets_in_time_range() bisects these lists, so they must be
        # sorted by timestamp and stay parallel to each other.
        self.time_index.sort(key=lambda entry: entry[0])
        self.packet_timestamps = [ts for ts, _ in self.time_index]

    def _index_packet(self, packet, channel_from_freq):
        """Record one packet in the time index and the per-channel tables."""
        ts = float(packet.frame_info.time_epoch)
        self.time_index.append((ts, packet))
        self.packet_timestamps.append(ts)

        if 'radiotap' not in packet or 'wlan' not in packet:
            return

        radio = packet.radiotap
        wlan = packet.wlan

        if not hasattr(radio, 'channel') or not hasattr(radio.channel, 'freq'):
            return

        channel = channel_from_freq(int(radio.channel.freq))
        subtype = int(getattr(wlan, 'type_subtype', '0'), 16)

        # Management frames: Probe Response (5) / Beacon (8)
        if subtype in (5, 8):
            self._process_management_frame(packet, wlan, radio, channel)

        # Every non-broadcast source/destination counts as a device on the channel.
        for mac in (getattr(wlan, 'sa', '').lower(), getattr(wlan, 'da', '').lower()):
            if mac and mac != _BROADCAST_MAC:
                self.channel_to_clients[channel].add(mac)

    @staticmethod
    def _parse_signal(raw):
        """Return ``raw`` as an int dBm value, or None when absent/unparseable."""
        if not raw:
            return None
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    def _process_management_frame(self, packet, wlan, radio, channel):
        """Extract SSID, BSSID, signal and Cisco client count from a beacon/probe response.

        Failures are logged at debug level and otherwise ignored so that a
        single odd frame cannot break indexing.
        """
        try:
            mgt = packet.get_multiple_layers('wlan.mgt')[0]
            tags = mgt._all_fields.get('wlan.tagged.all', {}).get('wlan.tag', [])
            if isinstance(tags, dict):
                # NOTE(review): presumably a single tag can arrive as a bare
                # dict instead of a list - normalise so the loop below works.
                tags = [tags]

            # NOTE(review): field name taken as-is from the original code;
            # confirm it matches the JSON layout of the tshark version in use.
            privacy_bit = mgt._all_fields.get('wlan_mgt.fixed.capabilities.privacy')
            is_open = (str(privacy_bit) != '1')

            ssid = None
            hidden_ssid = False

            for tag in tags:
                tag_number = tag.get('wlan.tag.number')

                if tag_number == '0':
                    raw_ssid = tag.get('wlan.ssid', '')
                    if not raw_ssid:
                        hidden_ssid = True
                        ssid = ''
                    else:
                        try:
                            ssid_bytes = bytes.fromhex(raw_ssid.replace(':', ''))
                            ssid = ssid_bytes.decode('utf-8', errors='replace')
                        except (AttributeError, TypeError, ValueError):
                            ssid = None

                elif tag_number == '133':
                    try:
                        num_clients = int(tag.get('wlan.cisco.ccx1.clients'))
                    except (TypeError, ValueError):
                        continue
                    if ssid:
                        self.cisco_ssid_clients[ssid].append(num_clients)
                    self.cisco_reported_clients.append(num_clients)

            if not ssid:
                # NOTE(review): hidden (empty) SSIDs leave here, so
                # ssid_hidden_status effectively only ever records False.
                return

            self.ssid_hidden_status[ssid] = hidden_ssid
            self.ssid_encryption_status.setdefault(ssid, is_open)
            self.ssid_packet_counts[ssid] += 1

            bssid = getattr(wlan, 'bssid', '').lower()
            if not bssid or bssid == _BROADCAST_MAC:
                return

            self.bssid_to_ssid[bssid] = ssid
            self.ssid_to_bssids[ssid].add(bssid)
            self.ssid_to_channels[ssid].add(channel)
            self.channel_to_aps[channel].add(bssid)

            # Parsed separately so a bad signal value cannot undo the
            # AP/channel bookkeeping above.
            signal = self._parse_signal(getattr(radio, 'dbm_antsignal', None))
            if signal is not None:
                self.ssid_signals[ssid].append(signal)
                self.packet_signals_by_channel[channel].append(signal)

        except Exception:
            logger.debug("Skipping management frame that could not be parsed", exc_info=True)

    def get_packets_in_time_range(self, start_ts, end_ts):
        """Return the packets whose timestamp lies in ``[start_ts, end_ts]``.

        Both bounds are inclusive.  Relies on ``packet_timestamps`` being
        sorted and parallel to ``time_index`` (established at load time), so
        the window is located by binary search instead of a full scan.
        """
        first = bisect_left(self.packet_timestamps, start_ts)
        last = bisect_right(self.packet_timestamps, end_ts)
        return [packet for _, packet in self.time_index[first:last]]

    def query_metrics(self, start_ts, end_ts, ap_bssid, ap_channel):
        """Compute the enrichment metrics for one AP over a time window.

        Args:
            start_ts: Window start (epoch seconds, inclusive).
            end_ts: Window end (epoch seconds, inclusive).
            ap_bssid: BSSID of the AP; matched case-insensitively.
            ap_channel: Channel of the AP.  Must be the same kind of value
                the channel index is keyed by (whatever
                ``get_channel_from_freq`` returns).

        Returns:
            A 13-tuple: (clients_on_ap, clients_on_channel, aps_on_channel,
            avg_ap_signal, max_ap_signal, unlinked_devices,
            cisco_avg_clients, cisco_max_clients, num_bssids,
            avg_ssid_signal, max_ssid_signal, num_channels_ssid,
            packet_count).  Metrics with no data are reported as 0.
        """
        packets = self.get_packets_in_time_range(start_ts, end_ts)
        # Index keys are stored lower-cased, so normalise before any lookup.
        bssid_key = (ap_bssid or '').lower()

        # Use indexed data instead of recalculating
        clients_on_ap = self._count_clients_on_ap(packets, bssid_key)
        clients_on_channel = len(self.channel_to_clients.get(ap_channel, ()))
        aps_on_channel = len(self.channel_to_aps.get(ap_channel, ()))

        avg_ap_signal, max_ap_signal = self._calc_signal_stats(ap_channel)
        unlinked_devices = self._count_unlinked_devices(packets, ap_channel)

        our_ssid = self.bssid_to_ssid.get(bssid_key)
        num_bssids = num_channels_ssid = 0
        avg_ssid_signal = max_ssid_signal = 0
        if our_ssid:
            num_bssids = len(self.ssid_to_bssids.get(our_ssid, ()))
            num_channels_ssid = len(self.ssid_to_channels.get(our_ssid, ()))
            ssid_signals = self.ssid_signals.get(our_ssid, [])
            if ssid_signals:  # an SSID can be known without any signal sample
                avg_ssid_signal = mean(ssid_signals)
                max_ssid_signal = max(ssid_signals)

        return (
            clients_on_ap, clients_on_channel, aps_on_channel,
            avg_ap_signal, max_ap_signal, unlinked_devices,
            self._cisco_avg_clients(our_ssid), self._cisco_max_clients(our_ssid),
            num_bssids, avg_ssid_signal, max_ssid_signal,
            num_channels_ssid, len(packets)
        )

    def _count_clients_on_ap(self, packets, ap_bssid):
        """Count MACs seen in more than ``_MIN_CLIENT_FRAMES`` frames involving the AP.

        Returns 0 when ``ap_bssid`` is empty or None.
        """
        if not ap_bssid:
            return 0
        ap_bssid = ap_bssid.lower()
        frames_per_mac = defaultdict(int)

        for packet in packets:
            try:
                if not hasattr(packet, "wlan"):
                    continue
                sa = getattr(packet.wlan, "sa", '').lower()
                da = getattr(packet.wlan, "da", '').lower()
                bssid = getattr(packet.wlan, "bssid", '').lower()
            except AttributeError:
                continue

            if ap_bssid not in (bssid, sa, da):
                continue

            # AP -> station: the destination is the client; otherwise the
            # (non-AP, non-broadcast) source is.
            if sa == ap_bssid and da and da != ap_bssid and not da.startswith(_BROADCAST_MAC):
                frames_per_mac[da] += 1
            elif sa and sa != ap_bssid and not sa.startswith(_BROADCAST_MAC):
                frames_per_mac[sa] += 1

        return sum(1 for count in frames_per_mac.values() if count > self._MIN_CLIENT_FRAMES)

    def _calc_signal_stats(self, ap_channel):
        """Return ``(mean, max)`` of the AP signal samples on a channel, or ``(0, 0)``."""
        signals = self.packet_signals_by_channel.get(ap_channel, [])
        return (mean(signals), max(signals)) if signals else (0, 0)

    def _count_unlinked_devices(self, packets, ap_channel):
        """Count distinct non-broadcast MACs in ``packets`` that are not known APs on the channel."""
        aps = self.channel_to_aps.get(ap_channel, set())
        ghost_clients = set()

        for packet in packets:
            try:
                if 'radiotap' not in packet or 'wlan' not in packet:
                    continue

                wlan = packet.wlan
                sa = getattr(wlan, 'sa', '').lower()
                da = getattr(wlan, 'da', '').lower()
            except Exception:
                logger.debug("Skipping packet while counting unlinked devices", exc_info=True)
                continue

            for mac in (sa, da):
                if mac and mac != _BROADCAST_MAC and mac not in aps:
                    ghost_clients.add(mac)

        return len(ghost_clients)

    def _cisco_avg_clients(self, ssid):
        """Return the mean Cisco-reported client count for ``ssid`` (2 decimals), or 0."""
        counts = self.cisco_ssid_clients.get(ssid)
        return round(mean(counts), 2) if counts else 0

    def _cisco_max_clients(self, ssid):
        """Return the highest Cisco-reported client count for ``ssid``, or 0."""
        counts = self.cisco_ssid_clients.get(ssid)
        return max(counts) if counts else 0