Source code for musicaiz.features.graphs

import matplotlib.pyplot as plt
import networkx as nx


def musa_to_graph(musa_object) -> nx.Graph:
    """Convert a Musa object into a graph whose nodes are its notes.

    Every note in ``musa_object.notes`` becomes a node keyed by its index in
    that sequence. Each node stores the attributes ``pitch``, ``velocity``,
    ``start`` (from ``note.start_ticks``) and ``end`` (from
    ``note.end_ticks``).

    For every pair of notes ``i < j`` at most one edge is added, always with
    ``weight=5``:

    * a ``"violet"`` edge when note ``j`` ends no later than note ``i``;
    * otherwise a ``"red"`` edge when the two notes are consecutive in the
      note list (``j == i + 1``).

    A similar symbolic music graph representation was introduced in:

    Jeong, D., Kwon, T., Kim, Y., & Nam, J. (2019, May). Graph neural
    network for music score data and modeling expressive piano performance.
    In International Conference on Machine Learning (pp. 3060-3070). PMLR.

    Parameters
    ----------
    musa_object
        Object exposing a ``notes`` iterable whose items provide ``pitch``,
        ``velocity``, ``start_ticks`` and ``end_ticks``.

    Returns
    -------
    nx.Graph
        Undirected graph with one node per note and the edges described
        above.
    """
    g = nx.Graph()
    for i, note in enumerate(musa_object.notes):
        g.add_node(
            i,
            pitch=note.pitch,
            velocity=note.velocity,
            start=note.start_ticks,
            end=note.end_ticks,
        )

    # Nodes were inserted as 0..n-1, so list position equals node id.
    ends = [data["end"] for _, data in g.nodes(data=True)]
    n_notes = len(ends)

    # Add edges. Only pairs with i < j are visited, so no self-loops or
    # duplicate pairs can be produced.
    for i in range(n_notes):
        for j in range(i + 1, n_notes):
            # TODO: Check these conditions. The previous implementation
            # tested both ``start_i >= start_j`` and ``start_i <= start_j``
            # (each combined with ``end_j <= end_i``), which together reduce
            # to the end-time comparison below; start times currently have
            # no influence on which edges are created.
            if ends[j] <= ends[i]:
                g.add_edge(i, j, weight=5, color="violet")
            elif j == i + 1:
                # Consecutive notes are always linked, even if the condition
                # above did not connect them.
                g.add_edge(i, j, weight=5, color="red")
    return g
def plot_graph(graph: nx.Graph, show: bool = False) -> None:
    """Plot a graph with matplotlib.

    A new 50x10 inch figure (100 dpi) is created and the graph is drawn on
    it with node labels. Each node is positioned at ``(start, pitch)`` taken
    from its node attributes. Edge colours and widths come from the
    ``color`` and ``weight`` edge attributes; an edge lacking an attribute
    falls back to ``"violet"`` and ``1`` respectively.

    Parameters
    ----------
    graph
        Graph to draw. Every node must carry ``start`` and ``pitch``
        attributes.
    show
        If True, ``plt.show()`` is called after drawing.
    """
    plt.figure(figsize=(50, 10), dpi=100)

    pos = {
        node: (data["start"], data["pitch"])
        for node, data in graph.nodes(data=True)
    }

    # Build per-edge style lists in the graph's own edge order so they stay
    # aligned with the edges ``nx.draw`` renders, even when only some edges
    # define ``color`` / ``weight``.
    edge_attrs = [data for _, _, data in graph.edges(data=True)]
    colors = [data.get("color", "violet") for data in edge_attrs]
    weights = [data.get("weight", 1) for data in edge_attrs]

    nx.draw(
        graph,
        pos,
        with_labels=True,
        edge_color=colors,
        width=weights,
        node_color="lightblue",
    )
    if show:
        plt.show()