Coverage for src/sensai/geoanalytics/geopandas/graph.py: 43%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import networkx as nx 

2import numpy as np 

3import scipy 

4from itertools import combinations 

5from scipy.spatial.distance import euclidean 

6from typing import Callable, Dict 

7from shapely.geometry import MultiLineString 

8import geopandas as gp 

9 

10from .coordinates import extract_coordinates_array, GeoDataFrameWrapper 

11 

12 

def delaunay_graph(data: np.ndarray, edge_weight: Callable[[np.ndarray, np.ndarray], float] = euclidean):
    """
    Build a networkx.Graph from the Delaunay triangulation of the given points.

    Every pair of vertices belonging to a common simplex of the triangulation is connected by an edge.
    An edge that belongs to several simplices is added once per simplex; the graph keeps a single edge
    carrying the weight that was computed last.

    :param data: the points to triangulate, passed unchanged to scipy.spatial.Delaunay
    :param edge_weight: function computing the weight of an edge from the coordinates of its two end points.
        The default is the euclidean distance
    :return: instance of networkx.Graph whose nodes are the point indices used by the triangulation and
        whose edges carry a "weight" attribute
    """
    triangulation = scipy.spatial.Delaunay(data)
    points = triangulation.points
    result = nx.Graph()

    for simplex in triangulation.simplices:
        for node_a, node_b in combinations(simplex, 2):
            # fancy indexing yields a fresh (2, dim) array, one row per end point of the edge
            coords_a, coords_b = points[[node_a, node_b]]
            result.add_edge(node_a, node_b, weight=edge_weight(coords_a, coords_b))
    return result

31 

32 

class SpanningTree:
    """
    Applies a tree-finding algorithm to the Delaunay graph of a set of datapoints and keeps the
    resulting tree together with the weights and end-point coordinates of its edges.
    """
    def __init__(self, datapoints: np.ndarray, tree_finder: Callable[[nx.Graph], nx.Graph] = nx.minimum_spanning_tree):
        """
        :param datapoints: the points to connect; they are converted with extract_coordinates_array
            before the Delaunay graph is built
        :param tree_finder: function mapping a graph to a subgraph; it is called with the Delaunay graph
            of the datapoints. The default is networkx' minimum_spanning_tree
        """
        datapoints = extract_coordinates_array(datapoints)
        self.tree = tree_finder(delaunay_graph(datapoints))

        weights = []
        # one array per tree edge, holding the coordinates of the edge's two end points
        self.coordinatePairs = []
        for start_index, end_index, attributes in self.tree.edges.data():
            weights.append(attributes["weight"])
            self.coordinatePairs.append(datapoints[[start_index, end_index]])
        self.edgeWeights = np.array(weights)

    def total_weight(self):
        """
        :return: the sum of the weights of all edges of the tree
        """
        return self.edgeWeights.sum()

    def num_edges(self):
        """
        :return: the number of edges of the tree
        """
        return len(self.tree.edges)

    def mean_edge_weight(self):
        """
        :return: the arithmetic mean of the weights of all edges of the tree
        """
        return self.edgeWeights.mean()

    def summary_dict(self) -> Dict[str, float]:
        """
        Coarse description of the tree.

        :return: dictionary with the number of edges, the total weight and the mean edge weight
        """
        edge_count = self.num_edges()
        weight_sum = self.total_weight()
        weight_mean = self.mean_edge_weight()
        return {"numEdges": edge_count, "totalWeight": weight_sum, "meanEdgeWeight": weight_mean}

70 

71 

class CoordinateSpanningTree(SpanningTree, GeoDataFrameWrapper):
    """
    A :class:`SpanningTree` built on the Delaunay graph of coordinates, which can additionally be
    exported as a shapely MultiLineString or as a GeoDataFrame.
    """
    def __init__(self, datapoints: np.ndarray, tree_finder: Callable[[nx.Graph], nx.Graph] = nx.minimum_spanning_tree):
        """
        :param datapoints: the coordinates to connect; they are converted with extract_coordinates_array
            and then handed to the :class:`SpanningTree` constructor
        :param tree_finder: function mapping a graph to a subgraph. The default is networkx'
            minimum_spanning_tree
        """
        coordinates = extract_coordinates_array(datapoints)
        super().__init__(coordinates, tree_finder=tree_finder)

    def multi_line_string(self):
        """
        :return: MultiLineString built from the coordinate pairs of the tree's edges
        """
        edge_segments = self.coordinatePairs
        return MultiLineString(edge_segments)

    def to_geodf(self, crs='epsg:3857'):
        """
        :param crs: projection assigned to the resulting frame. By default pseudo-mercator
        :return: GeoDataFrame of length 1 with the tree as MultiLineString instance
        """
        tree_geometry = self.multi_line_string()
        frame = gp.GeoDataFrame({"geometry": [tree_geometry]})
        frame.crs = crs
        return frame