# Copyright 2013 Canonical Ltd.  This software is licensed under the
# GNU Affero General Public License version 3 (see the file LICENSE).

"""Tests for `maastest.main`."""

from __future__ import (
    absolute_import,
    print_function,
    unicode_literals,
    )

__metaclass__ = type
__all__ = []

import logging
import os
from random import randint

from maastest import (
    detect_dhcp,
    main,
    utils,
    )
import mock
import testtools


def make_interface_name():
    """Return a made-up network interface name of the form `itf-<number>`.

    The number is picked at random between 0 and 100, inclusive.
    """
    suffix = randint(0, 100)
    return 'itf-%d' % suffix


class TestTestMAAS(testtools.TestCase):
    """Tests for the return codes produced by `main.main`."""

    def patch_logging(self):
        """Shut up logging for the duration of this test."""
        # Quick and dirty.  Replace with something nicer if appropriate.
        for logging_function in ['debug', 'info', 'error']:
            self.patch(logging, logging_function, mock.MagicMock())

    def pretend_to_be_user(self, uid):
        """Make the am-I-root check think we are user no. `uid`.

        Pass zero for root.
        """
        self.patch(os, 'geteuid', mock.MagicMock(return_value=uid))

    def patch_kvm_ok(self, ok=True):
        """Patch `check_kvm_ok` to return the given value."""
        self.patch(utils, 'check_kvm_ok', mock.MagicMock(return_value=ok))

    def patch_virtualization_type(self, virt_type=None):
        """Patch `virtualization_type` to return the given value.

        The default, type `None`, always says this is a physical system.
        """
        self.patch(
            utils, 'virtualization_type',
            mock.MagicMock(return_value=virt_type))

    def make_args(self, interface, interactive=True):
        """Create a fake arguments object with the given parameters."""
        args = mock.MagicMock()
        args.interface = interface
        args.interactive = interactive
        return args

    def test_refuses_to_run_if_interface_detects_dhcp_server(self):
        """A DHCP server found by `probe_dhcp` aborts the run."""
        self.patch_logging()
        self.pretend_to_be_user(0)
        self.patch_kvm_ok()
        self.patch_virtualization_type()

        interface = make_interface_name()
        server = '127.55.33.11'
        # Make the probe report exactly one DHCP server.
        self.patch(
            detect_dhcp, 'probe_dhcp', mock.MagicMock(return_value={server}))

        return_value = main.main(self.make_args(interface))

        self.assertEqual(
            main.RETURN_CODES.UNEXPECTED_DHCP_SERVER, return_value)

    def test_requires_sudo(self):
        """A non-root effective user id yields the `NOT_ROOT` return code."""
        self.patch_logging()
        # Any nonzero uid will do; zero would mean root.
        self.pretend_to_be_user(randint(100, 200))
        self.patch_kvm_ok()
        self.patch_virtualization_type()
        interface = make_interface_name()

        # Pass an arguments object, exactly as the other test in this class
        # does, rather than a raw argv-style list: a list has no `interface`
        # or `interactive` attributes.
        return_value = main.main(self.make_args(interface))

        self.assertEqual(main.RETURN_CODES.NOT_ROOT, return_value)


def make_proxy_url():
    """Return a made-up proxy URL with a random host number and port."""
    host_number = randint(1, 999999)
    port = randint(1, 65535)
    return 'http://%d.example.com:%d' % (host_number, port)


class TestSetUpProxy(testtools.TestCase):
    """Tests for `main.set_up_proxy`.

    Every test runs with `main.make_local_proxy_fixture` and `os.putenv`
    replaced by mocks.
    """

    def setUp(self):
        """Swap the proxy-fixture factory and `os.putenv` for mocks."""
        super(TestSetUpProxy, self).setUp()
        patch_targets = [
            (main, 'make_local_proxy_fixture'),
            (os, 'putenv'),
            ]
        for owner, attribute in patch_targets:
            self.patch(owner, attribute, mock.MagicMock())

    def make_args(self, http_proxy=None, disable_cache=False):
        """Build a fake arguments object carrying the proxy options."""
        fake_args = mock.MagicMock()
        fake_args.http_proxy = http_proxy
        fake_args.disable_cache = disable_cache
        return fake_args

    def test_returns_empty_if_caching_disabled(self):
        expected = ('', None)
        args = self.make_args(disable_cache=True)
        self.assertEqual(expected, main.set_up_proxy(args))

    def test_returns_existing_proxy_if_set(self):
        existing_proxy = make_proxy_url()
        expected = (existing_proxy, None)
        args = self.make_args(http_proxy=existing_proxy)
        self.assertEqual(expected, main.set_up_proxy(args))

    def test_creates_proxy_if_appropriate(self):
        result = main.set_up_proxy(self.make_args())
        url, fixture = result
        self.assertIsNotNone(fixture)
        self.addCleanup(fixture.cleanUp)
        self.assertEqual(fixture.get_url(), url)

    def test_cleans_up_proxy_on_failure(self):
        class DeliberateFailure(Exception):
            """Deliberately induced error for testing."""

        # The fixture that the mocked factory will hand out.
        fixture = main.make_local_proxy_fixture.return_value
        failing_get_url = mock.MagicMock(side_effect=DeliberateFailure)
        self.patch(fixture, 'get_url', failing_get_url)

        args = self.make_args()
        self.assertRaises(DeliberateFailure, main.set_up_proxy, args)

        # The fixture must have been cleaned up exactly once.
        self.assertEqual([mock.call()], fixture.cleanUp.mock_calls)

    def test_sets_env_to_existing_proxy(self):
        existing_proxy = make_proxy_url()
        main.set_up_proxy(self.make_args(http_proxy=existing_proxy))
        expected_calls = [
            mock.call(variable, existing_proxy)
            for variable in ('http_proxy', 'https_proxy')
            ]
        self.assertItemsEqual(expected_calls, os.putenv.mock_calls)

    def test_sets_env_to_created_proxy(self):
        created_url = main.set_up_proxy(self.make_args())[0]
        expected_calls = [
            mock.call(variable, created_url)
            for variable in ('http_proxy', 'https_proxy')
            ]
        self.assertItemsEqual(expected_calls, os.putenv.mock_calls)

    def test_leaves_env_unchanged_if_caching_disabled(self):
        args = self.make_args(disable_cache=True)
        main.set_up_proxy(args)
        self.assertEqual([], os.putenv.mock_calls)


class TestCheckAgainstDHCPServersFromVM(testtools.TestCase):
    """Tests for `main.check_against_dhcp_servers_from_vm`."""

    def make_fake_vm(self):
        """Create a minimal fake virtual-machine fixture."""
        fake = mock.MagicMock()
        fake.direct_ip = '192.168.%d.%d' % (randint(0, 255), randint(1, 254))
        return fake

    def patch_to_find_no_servers(self, fake):
        """Patch a fake VM fixture to find no DHCP servers.

        The fake's `run_command` will return a zero return code and empty
        output.  Returns `fake` for convenience.
        """
        self.patch(
            fake, 'run_command', mock.MagicMock(return_value=(0, '', '')))
        return fake

    def patch_to_find_servers(self, fake, servers):
        """Patch a fake VM fixture to find the given DHCP servers.

        The fake's `run_command` will return the number of servers as its
        return code, and a "DHCP servers detected" line listing them as its
        first output string.  Returns `fake` for convenience, like
        `patch_to_find_no_servers` does.
        """
        return_code = len(servers)
        output = "DHCP servers detected: %s" % ', '.join(servers)
        self.patch(
            fake, 'run_command',
            mock.MagicMock(return_value=(return_code, output, '')))
        return fake

    def patch_has_maas_probe_dhcp(self, present=True):
        """Patch `has_maas_probe_dhcp` to return the given answer."""
        self.patch(
            main, 'has_maas_probe_dhcp', mock.MagicMock(return_value=present))

    def make_dhcp_server_ip(self):
        """Return an IP address for a fictional unexpected DHCP server."""
        return '10.%d.%d.%d' % (
            randint(0, 254),
            randint(0, 254),
            randint(1, 254),
            )

    def test_passes_if_no_dhcp_servers_detected(self):
        """No servers reported: the check returns without raising."""
        self.patch_has_maas_probe_dhcp(present=True)
        fake_vm = self.make_fake_vm()
        self.patch_to_find_no_servers(fake_vm)
        main.check_against_dhcp_servers_from_vm(fake_vm)
        self.assertEqual(1, len(fake_vm.run_command.mock_calls))

    def test_fails_if_dhcp_server_detected(self):
        """A single unexpected server raises `ProgramFailure`."""
        self.patch_has_maas_probe_dhcp(present=True)
        fake_vm = self.make_fake_vm()
        self.patch_to_find_servers(fake_vm, [self.make_dhcp_server_ip()])
        exception = self.assertRaises(
            main.ProgramFailure,
            main.check_against_dhcp_servers_from_vm, fake_vm)
        self.assertEqual(1, len(fake_vm.run_command.mock_calls))
        self.assertEqual(
            main.RETURN_CODES.UNEXPECTED_DHCP_SERVER, exception.return_code)

    def test_fails_if_multiple_dhcp_servers_detected(self):
        """Several unexpected servers raise `ProgramFailure` as well."""
        self.patch_has_maas_probe_dhcp(present=True)
        fake_vm = self.make_fake_vm()
        self.patch_to_find_servers(
            fake_vm, [self.make_dhcp_server_ip() for _ in range(3)])
        exception = self.assertRaises(
            main.ProgramFailure,
            main.check_against_dhcp_servers_from_vm, fake_vm)
        self.assertEqual(1, len(fake_vm.run_command.mock_calls))
        self.assertEqual(
            main.RETURN_CODES.UNEXPECTED_DHCP_SERVER, exception.return_code)

    def test_passes_if_only_maas_dhcp_server_detected(self):
        """A server at the VM's own `direct_ip` is not an error."""
        self.patch_has_maas_probe_dhcp(present=True)
        fake_vm = self.make_fake_vm()
        self.patch_to_find_servers(fake_vm, [fake_vm.direct_ip])
        main.check_against_dhcp_servers_from_vm(fake_vm)
        self.assertEqual(1, len(fake_vm.run_command.mock_calls))

    def test_fails_if_maas_and_other_dhcp_server_detected(self):
        """The VM's own address plus another server still fails."""
        self.patch_has_maas_probe_dhcp(present=True)
        fake_vm = self.make_fake_vm()
        self.patch_to_find_servers(
            fake_vm, [fake_vm.direct_ip, self.make_dhcp_server_ip()])
        exception = self.assertRaises(
            main.ProgramFailure,
            main.check_against_dhcp_servers_from_vm, fake_vm)
        self.assertEqual(1, len(fake_vm.run_command.mock_calls))
        self.assertEqual(
            main.RETURN_CODES.UNEXPECTED_DHCP_SERVER, exception.return_code)

    def test_propagates_other_failure(self):
        """A nonzero return code without server output raises an error."""
        self.patch_has_maas_probe_dhcp(present=True)
        fake_vm = mock.MagicMock()
        # Failure with error output, but no "DHCP servers detected" line.
        fake_vm.run_command = mock.MagicMock(return_value=(1, '', 'Kaboom'))
        exception = self.assertRaises(
            Exception,
            main.check_against_dhcp_servers_from_vm, fake_vm)
        self.assertEqual(1, len(fake_vm.run_command.mock_calls))
        self.assertIn(
            "Call to maas-probe-dhcp failed in virtual machine",
            repr(exception))
        self.assertIn('Kaboom', repr(exception))

    def test_skips_if_maas_probe_dhcp_not_present(self):
        """Without `maas-probe-dhcp`, no command is run in the VM."""
        self.patch_has_maas_probe_dhcp(present=False)
        fake_vm = self.make_fake_vm()
        # Set the fake up to report an unexpected server, so that the check
        # could not pass quietly if it ran the probe anyway.
        self.patch_to_find_servers(fake_vm, [self.make_dhcp_server_ip()])
        # Actually exercise the code under test; without this call the
        # assertion below would hold trivially.
        main.check_against_dhcp_servers_from_vm(fake_vm)
        self.assertEqual(0, len(fake_vm.run_command.mock_calls))
