# Examples used in the documentation

import unittest
from test_commandline import confirm, run

class TestBenzotriazole(unittest.TestCase):
    """Documentation examples that use the benzotriazole sample file."""

    def test_benzotriazole(self):
        # Default options. The two numbers are the values handed to
        # confirm() -- presumably the expected atom and bond counts
        # (verify against test_commandline.confirm).
        cmdline = "../sample_files/benzotriazole.sdf"
        expected = (9, 10)
        confirm(cmdline, *expected)

    def test_select(self):
        # Same input, restricted with "--select 1-25".
        cmdline = "--select 1-25 ../sample_files/benzotriazole.sdf"
        expected = (16, 17)
        confirm(cmdline, *expected)

class Test_ar_clustered_3D_MM_3(unittest.TestCase):
    """Documentation examples that use the ar_clustered_3D_MM_3 sample file."""

    def test_default(self):
        # No extra options.
        cmdline = "../sample_files/ar_clustered_3D_MM_3.sdf"
        confirm(cmdline, 14, 13)

    def test_topology(self):
        # NOTE(review): other tests in this file spell the option
        # "--atom-compare"; confirm that "--compare" is also accepted.
        cmdline = "../sample_files/ar_clustered_3D_MM_3.sdf --compare topology"
        confirm(cmdline, 17, 20)

    def test_elements(self):
        # Expects the same pair of values as the topology comparison above.
        cmdline = "../sample_files/ar_clustered_3D_MM_3.sdf --compare elements"
        confirm(cmdline, 17, 20)

class TestSMARTSOutputs(unittest.TestCase):
    """Checks on the first value returned by run() for the default output format.

    Every assertion passes the captured output as its message so that a
    failure shows what the command actually produced.
    """

    def test_normal(self):
        # A search that runs to completion reports its size and status.
        output, ignore = run("../sample_files/p38_clustered_3D_MM_9.sdf")
        # The message must be `output`: the name `line` used here before was
        # never defined, so a failing check raised NameError instead of
        # AssertionError.
        assert "12 atoms 11 bonds (complete search)" in output, output

    def test_no_mcs(self):
        # Input for which no common substructure is reported.
        output, ignore = run("../sample_files/egfr_clustered_3D_MM_2.sdf")
        assert "No MCS found" in output, output

    def test_timeout(self):
        # A one second timeout on the "lengthy" input should be reported
        # in the output as a timed-out search.
        output, ignore = run("../tests/lengthy.smi --timeout 1")
        assert "(timed out)" in output, output

class TestFragmentSmilesOutput(unittest.TestCase):
    """Checks on the "fragment-smiles" output format.

    Each output line is expected to hold two whitespace-separated fields:
    a fragment SMILES followed by a structure identifier.
    """

    def test_output_one(self):
        # Without --output-all only a single record is written.
        output, ignore = run("../sample_files/p38_clustered_3D_MM_9.sdf --output-format fragment-smiles")
        fields = output.split()
        assert len(fields) == 2, (output, fields)
        assert "c-c" in fields[0], output
        assert fields[1] == "ZINC03832128", output

    def test_output_all(self):
        # With --output-all there is one record per line; collect the
        # identifier from each and compare against the expected order.
        output, ignore = run("../sample_files/p38_clustered_3D_MM_9.sdf --output-format fragment-smiles --output-all")
        ids = []
        for line in output.splitlines():
            fields = line.split()
            assert "c-c" in fields[0], line
            ids.append(fields[1])
        expected_ids = """\
ZINC03832128
ZINC03815736
ZINC03832064
ZINC03815693
ZINC03815735
ZINC03815689
ZINC03815771
ZINC04617902
ZINC04617926
ZINC03815704
ZINC03815752
ZINC03815731
ZINC00020320
ZINC03815775
ZINC04617907
ZINC03815786
ZINC04617909
ZINC03832115
ZINC03815724
ZINC03832140
ZINC03832156
ZINC03832165
ZINC03815680
ZINC03815759
ZINC03815725
ZINC03832073
ZINC03832054
ZINC03815776
ZINC03815705
ZINC03815700
ZINC03995359
ZINC03794516
ZINC03815770
ZINC04617912
ZINC04617922
ZINC03815741
ZINC04617919
ZINC03794511
ZINC00833184
ZINC03815757
ZINC04617900
ZINC04617916
ZINC03815756
ZINC04617911
ZINC04617915
ZINC04617908
ZINC03815758
ZINC03815709
ZINC03815779
ZINC03815721
ZINC03832090
ZINC03815747
ZINC03815727
ZINC03815746
ZINC03815710
ZINC04617917
ZINC03815615
ZINC03815772
ZINC04617913
ZINC03815761
ZINC03815617""".splitlines()
        # zip() stops at the shorter sequence, so without this check an
        # empty or truncated output would pass the loop below vacuously.
        self.assertEqual(len(ids), len(expected_ids),
                         "found %d ids, expected %d" % (len(ids), len(expected_ids)))
        for i, (found_id, expected_id) in enumerate(zip(ids, expected_ids)):
            self.assertEqual(found_id, expected_id, "%r != %r (%d)" % (found_id, expected_id, i))


class TestAtomCompare(unittest.TestCase):
    """Documentation examples for the --atom-compare option."""

    def test_default_setting(self):
        # No --atom-compare given.
        cmdline = "../sample_files/na_clustered_3D_MM_1.sdf"
        confirm(cmdline, 4, 3)

    def test_explicit_default_setting(self):
        # Naming "elements" explicitly expects the same values as the
        # run with no option.
        cmdline = "--atom-compare elements ../sample_files/na_clustered_3D_MM_1.sdf"
        confirm(cmdline, 4, 3)

    def test_using_any(self):
        # "any" expects a larger result than the element-based comparison.
        cmdline = "--atom-compare any ../sample_files/na_clustered_3D_MM_1.sdf"
        confirm(cmdline, 5, 4)


# Identifiers expected, in this order, as the first line of each record when
# every structure of na_clustered_3D_MM_1.sdf is written out (see the
# "--output-all" tests in TestSDFOutput).
na_clustered_3D_MM_1_ids = [
    "ZINC03581099",
    "ZINC03581100",
    "ZINC03833958",
    "ZINC03581810",
    "ZINC04134481",
    "ZINC03833955",
    "ZINC03833968",
    "ZINC03833957",
    "ZINC03833960",
    "ZINC04134482",
    "ZINC03833959",
    "ZINC04134483",
    "ZINC03581157",
    "ZINC03581156",
    "ZINC04134492",
    "ZINC04134484",
    "ZINC04134485",
    "ZINC04134486",
    "ZINC03929509",
    "ZINC04134487",
    "ZINC03833961",
    "ZINC04134501",
    "ZINC02047891",
    "ZINC01703370",
    "ZINC03833967",
    "ZINC04134499",
    "ZINC04134498",
    "ZINC04134497",
    "ZINC04134493",
    "ZINC04134490",
    "ZINC04134489",
    "ZINC04134488",
    "ZINC04646290",
    "ZINC04134494",
    "ZINC04134495",
    "ZINC04134496",
    "ZINC03833956",
    "ZINC04134500",
    "ZINC04646291",
    "ZINC04646292",
    "ZINC04646293",
    "ZINC04646294",
]

class TestSDFOutput(unittest.TestCase):
    """Checks on the "fragment-sdf" and "complete-sdf" output formats.

    In the records examined here the first line is the structure
    identifier and the first six characters of the fourth line hold the
    atom and bond counts.
    """

    def test_fragment_sdf(self):
        output, ignore = run("../sample_files/na_clustered_3D_MM_1.sdf --output-format fragment-sdf")
        record = output.splitlines()
        self.assertEqual(record[0], "ZINC03581099")
        # atom and bond counts
        self.assertEqual(record[3][:6], "  4  3")

    def test_compete_sdf(self):
        output, ignore = run("../sample_files/na_clustered_3D_MM_1.sdf --output-format complete-sdf --save-counts-tag mcs-counts")
        record = output.splitlines()
        self.assertEqual(record[0], "ZINC03581099")
        # atom and bond counts
        self.assertEqual(record[3][:6], " 14 14")

    @staticmethod
    def _get_sdf_info(output):
        """Return (ids, sizes) gathered from every record in `output`.

        `ids` holds the first line of each record and `sizes` the first
        six characters of each record's fourth line.  A line equal to
        "$$$$" ends the current record.
        """
        names = []
        counts = []
        offset = 0  # zero-based position of the line within its record
        for text in output.splitlines():
            if offset == 0:
                names.append(text)
            elif offset == 3:
                counts.append(text[:6])
            # The terminator restarts the numbering for the next record.
            offset = 0 if text == "$$$$" else offset + 1

        return names, counts

    def _check_ids(self, ids):
        # Same number of records as expected, and the same identifiers
        # in the same order.
        self.assertEqual(len(ids), len(na_clustered_3D_MM_1_ids))
        for position, pair in enumerate(zip(ids, na_clustered_3D_MM_1_ids)):
            found, wanted = pair
            self.assertEqual(found, wanted, "%r != %r (%d)" % (found, wanted, position))

    def test_fragment_sdf_all(self):
        output, ignore = run("../sample_files/na_clustered_3D_MM_1.sdf --output-format fragment-sdf --output-all --save-counts-tag my-mcs-counts")
        ids, sizes = self._get_sdf_info(output)
        self._check_ids(ids)

        # Every fragment record must carry the same counts line ...
        distinct_sizes = set(sizes)
        self.assertEqual(len(distinct_sizes), 1, distinct_sizes)
        self.assertEqual(next(iter(distinct_sizes)), "  4  3")
        # ... and the requested counts tag once per record.
        tag_hits = output.count("<my-mcs-counts>\n1 4 3")
        self.assertEqual(tag_hits, len(ids))

    def test_complete_sdf_all(self):
        output, ignore = run("../sample_files/na_clustered_3D_MM_1.sdf --output-format complete-sdf --output-all --save-counts-tag mcs-counts")
        ids, sizes = self._get_sdf_info(output)
        self._check_ids(ids)

        # Complete records are not expected to all share one counts line.
        self.assertNotEqual(len(set(sizes)), 1)
        tag_hits = output.count("<mcs-counts>")
        self.assertEqual(tag_hits, len(ids))

    def test_fragment_tags(self):
        cmdline = ("../sample_files/na_clustered_3D_MM_1.sdf --output-format fragment-sdf "
                   "--save-counts-tag counts --save-smarts-tag smarts --save-smiles-tag smiles")
        output, ignore = run(cmdline)
        assert "<counts>\n1 4 3\n" in output, output
        # The SMARTS and SMILES text may differ between RDKit and fmcs
        # versions, so only the start of the SMARTS is checked and either
        # SMILES spelling is accepted.
        assert "<smarts>\n[" in output, output
        smiles_forms = ("<smiles>\nCCNc\n", "<smiles>\ncNCC\n")
        assert any(form in output for form in smiles_forms), output

    def test_complete_tags(self):
        cmdline = ("../sample_files/na_clustered_3D_MM_1.sdf --output-format complete-sdf "
                   "--save-counts-tag counts --save-smarts-tag smarts --save-smiles-tag smiles "
                   "--save-atom-indices-tag blah")
        output, ignore = run(cmdline)
        assert "<counts>\n1 4 3\n" in output, output
        # The SMARTS and SMILES text may differ between RDKit and fmcs
        # versions, so only the start of the SMARTS is checked and either
        # SMILES spelling is accepted.
        assert "<smarts>\n[" in output, output
        smiles_forms = ("<smiles>\nCCNc\n", "<smiles>\ncNCC\n")
        assert any(form in output for form in smiles_forms), output
        assert "<blah>\n0 1 3 4\n" in output, output

class TestModifyBondComparisons(unittest.TestCase):
    """Documentation examples for options that change how bonds are matched."""

    def test_default(self):
        # Baseline run with no extra options.
        cmdline = "../sample_files/pdgfrb_clustered_3D_MM_4.sdf"
        confirm(cmdline, 11, 11)

    def test_complete_rings_only(self):
        # Adding --complete-rings-only expects a smaller result than the
        # baseline above.
        cmdline = "../sample_files/pdgfrb_clustered_3D_MM_4.sdf --complete-rings-only"
        confirm(cmdline, 7, 7)


class TestVerbosity(unittest.TestCase):
    """Checks on the second value returned by run() at two verbosity levels.

    Each test lists (text, must_be_present) pairs which are checked in
    order against that value; the value itself is the failure message.
    """

    def test_verbose(self):
        # A single --verbose: summary lines only, no per-step detail.
        ignore, errout = run("--verbose --timeout 1 ../tests/lengthy.smi")
        expectations = (
            ("Loaded 2 structures", True),
            ("3 atoms 2 bonds", False),
            ("subgraphs enumerated", False),
            ("Total time", True),
            ("seconds: ", True),
        )
        for text, must_be_present in expectations:
            assert (text in errout) == must_be_present, errout

    def test_verbose_verbose(self):
        # --verbose given twice (second time as -v): the detail lines
        # absent above must now appear as well.
        ignore, errout = run("--verbose -v --timeout 1 ../tests/lengthy.smi")
        expectations = (
            ("Loaded 2 structures", True),
            ("3 atoms 2 bonds", True),
            ("subgraphs enumerated", True),
            ("Total time", True),
            ("seconds: ", True),
        )
        for text, must_be_present in expectations:
            assert (text in errout) == must_be_present, errout

class TestSmilesIsotopes(unittest.TestCase):
    """Documentation examples for "--atom-compare isotopes" on SMILES input.

    Every assertion passes the captured output as its message so that a
    failure shows what the command actually produced.
    """

    def test_smsd_arom(self):
        confirm("../sample_files/smsd_arom.smi", 6, 6)

    def test_smsd_arom_isotopes(self):
        # Example of the expected output line:
        # [9*]-[6*]-1-[6*]-[6*]-[6*]-[6*]-[6*]-1 7 atoms 7 bonds (complete search)
        output, errout = run("--atom-compare isotopes ../tests/smsd_arom_isotopes.smi")
        assert "7 atoms 7 bonds (complete search)" in output, output
        # Make sure it's giving me a halogen
        assert "[9*]" in output, output

    def test_smsd_arom_isotopes_smiles_fragments(self):
        # One fragment line per input structure is expected (5 in total),
        # with both bromine- and fluorine-containing fragments present.
        output, errout = run("--atom-compare isotopes --output-all --output-format fragment-smiles ../tests/smsd_arom_isotopes.smi ")
        assert "BrC" in output, output
        assert "FC" in output, output
        assert len(output.splitlines()) == 5, output


class TestSDFAtomClasses(unittest.TestCase):
    """Documentation examples for the --atom-class-tag option on SDF input.

    Every assertion passes the captured output as its message so that a
    failure shows what the command actually produced.
    """

    def test_atom_classes(self):
        output, errout = run("--atom-class-tag atom_classes ../tests/smsd_arom_atom_classes.sdf")
        assert "7 atoms 7 bonds (complete search)" in output, output
        # Make sure it's giving me a halogen
        assert "[9*]" in output, output

    def test_for_modified_class_tag_in_fragment(self):
        output, errout = run("--atom-class-tag atom_classes ../tests/smsd_arom_atom_classes.sdf "
                             "--output-format fragment-sdf --save-smiles-tag SPAM")

        # The fragment record carries an atom_classes tag with one value
        # per fragment atom.  The commented-out line is an alternative
        # atom ordering kept from the original test.
        #assert "<atom_classes>\n9 6 6 6 6 6 6\n" in output, output
        assert "<atom_classes>\n6 6 6 6 6 6 9\n" in output, output

        assert "<SPAM>\nClC1CCCCC1\n" in output, output

    def test_for_indices_in_complete_sdf(self):
        output, errout = run("--atom-class-tag atom_classes ../tests/smsd_arom_atom_classes.sdf "
                             "--output-format complete-sdf --save-smiles-tag SPAM "
                             "--save-atom-indices-tag INDICES")

        # The complete record keeps the full atom_classes tag; the matched
        # atoms are listed under the requested INDICES tag.  The
        # commented-out line is an alternative index ordering kept from
        # the original test.
        assert "<atom_classes>\n6 6 9 6 6 9 6 6\n" in output, output
        #assert "<INDICES>\n2 1 0 7 6 4 3\n" in output, output
        assert "<INDICES>\n0 1 3 4 6 7 2\n" in output, output

        assert "<SPAM>\nClC1CCCCC1\n" in output, output
        

# TODO: --output

# When this file is executed directly (rather than imported), hand control
# to unittest's command-line entry point.
if __name__ == "__main__":
    unittest.main()
    
