#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

from winswitch.util.simple_logger import Logger
logger = Logger("conch_server_channel")
debug_import = logger.get_debug_import()

debug_import("consts")
from winswitch.consts import NOTIFY_ERROR
debug_import("server_config")
from winswitch.objects.server_config import ServerConfig
debug_import("server_line_connection")
from winswitch.client.server_line_connection import ServerLineConnection
debug_import("conch_channels")
from winswitch.net.conch_channels import ExecLineChannel
debug_import("conch_util")
from winswitch.net.conch_util import ConchFactory, ConchTransport
debug_import("config")
from winswitch.util.config import modify_server_config
debug_import("common")
from winswitch.util.common import hash_text
debug_import("main_loop")
from winswitch.util.main_loop import callLater



class ConchLinkTransport(ConchTransport):
	"""
	ConchTransport subclass which adds a stop() method,
	so that the link's connection can be dropped on request.
	"""

	def stop(self, retry=False, message=None):
		"""
		Logs the call, then asks the underlying transport to lose its connection.
		The 'retry' and 'message' arguments are only logged by this method.
		"""
		self.slog(None, retry, message)
		underlying = self.transport
		underlying.loseConnection()

class ConchLinkFactory(ConchFactory):
	"""
	Wraps ConchFactory so we can pass ServerConfig objects in.
	Also records connection state and calls ready() when authentication is complete.

	Subclasses must override start_link_channel() to create the channel
	that carries the line protocol.
	"""
	def __init__(self, server_link):
		"""
		Takes the connection details from server_link.server and hands them to ConchFactory.

		@param server_link: the link object owning this factory, it must provide
			the 'server' and 'dialog_util' attributes read here.
		"""
		self.server_link = server_link
		server = server_link.server
		# set by ready() once the link channel has been started:
		self.link_channel = None
		ConchFactory.__init__(self, server.username, server.password, server.host, server.port,
								server.ssh_pub_keyfile, server.ssh_keyfile, server.ssh_keyfile_passphrase, server.ssh_hostkey_fingerprint)
		# use our transport class, which adds stop():
		self.transport = ConchLinkTransport
		self.server_name = server.name
		self.dialog_util = server_link.dialog_util
		# ready() is registered as a channel constructor, it receives the connection:
		self.channel_constructors.append(self.ready)

	def set_fingerprint(self, new_fingerprint):
		"""
		Stores the new SSH host key fingerprint on the server object
		and saves that field via modify_server_config().
		"""
		self.sdebug(None, new_fingerprint)
		self.server_link.server.ssh_hostkey_fingerprint = new_fingerprint
		modify_server_config(self.server_link.server, ["ssh_hostkey_fingerprint"])

	def set_password(self, new_password):
		"""
		Stores the new password on the server object
		and saves the "encrypted_password" field via modify_server_config().
		"""
		# log hash_text() of the password rather than the password itself:
		self.sdebug(None, hash_text(new_password))
		self.server_link.server.set_password(new_password)
		modify_server_config(self.server_link.server, ["encrypted_password"])

	def set_passphrase(self, new_passphrase):
		"""
		Stores the new SSH keyfile passphrase on the server object
		and saves the "encrypted_ssh_keyfile_passphrase" field via modify_server_config().
		"""
		# log hash_text() of the passphrase rather than the passphrase itself:
		self.sdebug(None, hash_text(new_passphrase))
		self.server_link.server.set_ssh_keyfile_passphrase(new_passphrase)
		modify_server_config(self.server_link.server, ["encrypted_ssh_keyfile_passphrase"])

	def ready(self, connection):
		"""
		Starts the link_channel and hooks it with ServerLineConnection

		@param connection: the connection object, it is recorded on the server_link
			and passed on to start_link_channel().
		"""
		self.slog(None, connection)
		self.server_link.conch_connection = connection
		def link_ready(*args):
			# NOTE(review): this reads self.link_channel, which is only assigned
			# after start_link_channel() returns below - presumably the channel
			# never fires this callback synchronously; confirm against ExecLineChannel.
			self.slog(None, *args)
			client = ServerLineConnection(self.server_link, self.link_channel.writeLine, self.link_channel.stop, self.link_channel.is_connected)
			self.server_link.set_client(client)
			self.server_link.set_server_status(self, ServerConfig.STATUS_CONNECTED)
			# lines received on the channel are handled by the client:
			self.link_channel.line_callback = client.handle_command
			client.connectionMade()
		def link_closed(*args):
			self.slog(None, *args)
			# only close the connections if we are still the link's current factory:
			if self.server_link.client_factory==self:
				self.server_link.close_connections()
		self.link_channel = self.start_link_channel(connection, link_ready, link_closed)
		# schedule the timeout check after server.timeout:
		callLater(self.server_link.server.timeout, self.server_link.check_connection_timeout)

	def start_link_channel(self, connection, link_ready_cb, link_closed_cb=None):
		"""
		Must be overridden: create and return the ExecLineChannel for this link.

		The signature matches the way ready() calls it (and the subclasses' overrides):
		@param connection: the connection the channel is created on
		@param link_ready_cb: callback given to the channel, ready() passes link_ready
		@param link_closed_cb: callback given to the channel, ready() passes link_closed
		@raise NotImplementedError: always, in this base class.
		"""
		raise NotImplementedError("subclasses must override this method and return an ExecLineChannel instance!")

	def clientConnectionLost(self, connector, reason):
		"""
		Lets ConchFactory handle the lost connection, then updates our state via disconnected().
		"""
		ConchFactory.clientConnectionLost(self, connector, reason)
		self.disconnected()

	def clientConnectionFailed(self, connector, reason):
		"""
		Lets ConchFactory handle the failure, updates our state via disconnected(),
		then sends an error notification through the server_link.
		"""
		ConchFactory.clientConnectionFailed(self, connector, reason)
		self.disconnected()
		server = self.server_link.server
		message = "Failed to connect to the SSH server at %s:%s," % (server.host, server.port)
		# the hostname/port hint is only added when server.dynamic is not set:
		if not server.dynamic:
			message += "\nplease ensure that the hostname and port number are correct,"
		message += "\nIs the SSH server running on that host?\nIs a firewall blocking access to it?"
		self.server_link.notify("Cannot connect to SSH server on %s" % server.get_display_name(), message,
							notification_type=NOTIFY_ERROR)

	def disconnected(self):
		"""
		Clears the link's conch_connection, sets the server status to
		STATUS_DISCONNECTED and calls touch() on the server object.
		"""
		self.server_link.conch_connection = None
		self.server_link.set_server_status(self, ServerConfig.STATUS_DISCONNECTED)
		self.server_link.server.touch()
		self.serror()

class ConchFixedPortLinkFactory(ConchLinkFactory):
	"""
	Link factory for tunnelling to a host when we already know which port the server is on:
	the link channel runs "winswitch_stdio_tcp" with the server's command host and port.
	"""
	def __init__(self, server_link):
		ConchLinkFactory.__init__(self, server_link)

	def start_link_channel(self, connection, link_ready_cb, link_closed_cb):
		"""
		Returns a new ExecLineChannel on 'connection' for the "winswitch_stdio_tcp" command,
		given both callbacks (the third constructor argument is passed as None).
		"""
		server = self.server_link.server
		target = (server.command_host, server.command_port)
		command = "winswitch_stdio_tcp %s %s" % target
		return ExecLineChannel(command, link_ready_cb, None, link_closed_cb, connection)

class ConchStdioRemoteClientLinkFactory(ConchLinkFactory):
	"""
	Link factory for tunnelling to a host when the target port is not known:
	the link channel runs "winswitch_stdio_socket" instead of connecting to a port.
	"""
	def __init__(self, server_link):
		ConchLinkFactory.__init__(self, server_link)

	def start_link_channel(self, connection, link_ready_cb, link_closed_cb):
		"""
		Returns a new ExecLineChannel on 'connection' for the "winswitch_stdio_socket" command,
		given both callbacks (the third constructor argument is passed as None).
		"""
		command = "winswitch_stdio_socket"
		return ExecLineChannel(command, link_ready_cb, None, link_closed_cb, connection)
