#!/usr/bin/env python3

from __future__ import print_function
import requests
import getpass
import json
import re

# Python 2/3 compatibility shim: where raw_input exists (Python 2), rebind
# input to it so that input() returns the typed text unevaluated. On Python 3
# the name lookup raises NameError and the built-in input() is left untouched.
# NOTE(review): the f-strings used further down in this file already require
# Python 3.6+, so this shim is presumably vestigial - confirm before removing.
try:
    raw_input
except NameError:
    pass
else:
    input = raw_input


# Importing the specific warning class appears to be version-sensitive, so it is left disabled:
# from urllib3.exceptions import InsecureRequestWarning

# Escape hatch for SSL certificate trouble (for example when connecting to your own machine):
# set verifySsl to False to skip certificate verification on every request made by this script.
verifySsl = True

if not verifySsl:
	# The narrower call below should allegedly be enough, but was found not to work:
	# requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
	# ... hence every urllib3 warning is switched off instead.
	requests.packages.urllib3.disable_warnings()
	print("WARNING: SSL certificate verification (and warnings) disabled. This script will blindly trust whatever it connects to.")
	
# Base URL of the DataGateway deployment. Surrounding whitespace and trailing slashes are dropped
# so that the "/..." paths appended below never produce a double slash.
datagateway_url = input("DataGateway url: ").strip().rstrip("/")

# The top-level settings.json carries a "plugins" list; each entry has a "name" and the "src" URL of its main JS bundle.
scigatewaySettingsJson = json.loads(requests.get(datagateway_url + "/settings.json", verify=verifySsl).text)

download_plugin = next((x for x in scigatewaySettingsJson["plugins"] if x["name"] == "datagateway-download"), None)
if download_plugin is None:
	# Exit with a readable message instead of a bare StopIteration traceback.
	raise SystemExit("No 'datagateway-download' plugin found in " + datagateway_url + "/settings.json")

# The plugin settings file is requested from the location that serves the plugin's main*.js bundle,
# so strip the bundle file name off "src".
datagateway_download_url = re.sub(r"/main.*\.js", "", download_plugin["src"])

# add datagateway_url back in if it's a relative URL
# (an empty string is what remains when "src" is just "/main.js"; indexing [0] would raise IndexError on it)
if datagateway_download_url == "" or datagateway_download_url.startswith("/"):
	datagateway_download_url = datagateway_url + datagateway_download_url

settingsJson = json.loads(requests.get(datagateway_download_url + "/datagateway-download-settings.json", verify=verifySsl).text)

facility_name = settingsJson["facilityName"]

# Collect the IDS URL of every configured access method, in the order the settings list them.
access_methods = settingsJson["accessMethods"]
ids_urls = [access_methods[method_name]["idsUrl"] for method_name in access_methods]

if len(ids_urls) <= 1:
	# Nothing to choose from: use the default IDS from the settings.
	ids_url = settingsJson["idsUrl"]
else:
	print("Available IDS servers: ")
	for position, candidate_url in enumerate(ids_urls, start=1):
		print(str(position) + ": " + candidate_url)
	try:
		# The menu is 1-based, the list is 0-based.
		option_number = int(input("Choose a number to select that IDS server: ")) - 1
	except ValueError:
		# Out-of-range marker; handled by the Invalid Input branch below
		option_number = -2
	if 0 <= option_number < len(ids_urls):
		ids_url = ids_urls[option_number]
	else:
		print("Invalid input,  selecting default IDS")
		ids_url = settingsJson["idsUrl"]

print("IDS url: " + ids_url)

# Base URL used for every "/admin/..." and "/user/..." request made further down.
topcat_url = settingsJson["downloadApiUrl"]

# Ask the selected IDS which ICAT server it uses; "/icat" is appended to the reported URL.
icat_base_url = requests.get(ids_url + "/getIcatUrl", verify=verifySsl).text
icat_url = icat_base_url + "/icat"

print("ICAT url: " + icat_url)

# The ICAT properties document is read below for its list of authenticators.
properties_response = requests.get(icat_url + "/properties", verify=verifySsl)
icat_properties = json.loads(properties_response.text)


# Mnemonics of the authenticators listed in the ICAT properties.
authentication_plugins = [entry["mnemonic"] for entry in icat_properties["authenticators"]]

if len(authentication_plugins) <= 1:
	authentication_plugin = authentication_plugins[0]
else:
	print("Available authentication plugins: ")
	for mnemonic in authentication_plugins:
		print(" * " + mnemonic)

	# The answer is used as typed; it is not checked against the list above.
	authentication_plugin = input("Which authentication plugin?: ")

username = input("Username: ")
password = getpass.getpass("Password: ")

# Login payload: the plugin mnemonic plus a list of single-key credential objects.
credentials = [
	{"username": username},
	{"password": password}
]
auth = json.dumps({"plugin": authentication_plugin, "credentials": credentials})

# Log in to ICAT and keep the session id used by every later request.
# `response` is pre-set so the error handler can tell "the POST itself failed"
# (no response at all) apart from "the server answered without a sessionId";
# previously the former case died with a NameError that hid the real error.
response = None
try:
    response = requests.post(icat_url + "/session", {"json": auth}, verify=verifySsl)
    session_id = json.loads(response.text)["sessionId"]
except Exception as e:
    import sys
    response_text = response.text if response is not None else "<no response received>"
    sys.exit("Couldn't determine sessionId: " + str(e) + "; response: " + response_text)


def input_download_ids():
	"""Prompt for download ids and return them as a list of strings.

	The typed line is split on any run of whitespace, so an empty answer
	gives an empty list. The ids are not validated or converted.
	"""
	return input("Enter one or more space separated download id(s): ").split()


def show_download():
	"""Print the raw "/admin/downloads" response text for each download id entered by the user."""
	for download_id in input_download_ids():
		query_params = {
			"facilityName": facility_name,
			"sessionId": session_id,
			"queryOffset": f"where download.id = {download_id}"
		}
		reply = requests.get(topcat_url + "/admin/downloads", params=query_params, verify=verifySsl)
		print(reply.text)


def list_file_locations():
	"""List the datafile locations covered by each download id entered by the user.

	For every id the user is asked for an optional output file name, the
	download record is fetched from "/admin/downloads", and each of its
	download items (investigation, dataset or datafile) is resolved to
	datafile locations through the ICAT "/entityManager" endpoint. The sorted
	locations are written one per line to the named file, or printed when no
	file name was given. Items of any other entity type are ignored.

	An id for which the lookup does not return a non-empty list is reported
	and skipped, instead of aborting the script with an IndexError/KeyError.
	"""
	# Query prefix per download-item entity type; the item's entity id is appended.
	location_queries = {
		"investigation": "select datafile.location from Datafile datafile, datafile.dataset as dataset, dataset.investigation as investigation where investigation.id = ",
		"dataset": "select datafile.location from Datafile datafile, datafile.dataset as dataset where dataset.id = ",
		"datafile": "select datafile.location from Datafile datafile where datafile.id = ",
	}
	for download_id in input_download_ids():
		output_file_name = input("Output file name (optional): ")
		downloads = json.loads(requests.get(topcat_url + "/admin/downloads", params={
			"facilityName": facility_name,
			"sessionId": session_id,
			"queryOffset": f"where download.id = {download_id}"
		}, verify=verifySsl).text)
		if not isinstance(downloads, list) or not downloads:
			# Unknown id, or a non-list reply: report what came back and move on.
			print("No download found for id " + str(download_id) + ": " + json.dumps(downloads))
			continue
		download = downloads[0]
		datafile_locations = []
		for download_item in download["downloadItems"]:
			query_prefix = location_queries.get(download_item["entityType"])
			if query_prefix is None:
				continue
			datafile_locations.extend(json.loads(requests.get(icat_url + "/entityManager", params={
				"sessionId": session_id,
				"query": query_prefix + str(download_item["entityId"])
			}, verify=verifySsl).text))
		datafile_locations.sort()
		if output_file_name != "":
			# The context manager closes the file even if a write fails.
			with open(output_file_name, "w") as output_file:
				for datafile_location in datafile_locations:
					output_file.write(datafile_location + "\n")
		else:
			for datafile_location in datafile_locations:
				print(datafile_location)


def prepare_download():
	"""PUT to the "/admin/download/<id>/prepare" endpoint for each entered download id.

	The HTTP status code and response body of every request are printed.
	"""
	for download_id in input_download_ids():
		prepare_url = f"{topcat_url}/admin/download/{download_id}/prepare"
		form = {
			"facilityName": facility_name,
			"sessionId": session_id,
		}
		outcome = requests.put(prepare_url, data=form, verify=verifySsl)
		print(outcome.status_code, outcome.text)


def expire_download():
	"""Set the status of each download id entered by the user to 'EXPIRED'."""
	entered_ids = input_download_ids()
	for entered_id in entered_ids:
		_expire_download(download_id=entered_id)


def _expire_download(download_id):
	"""Set one download's status to 'EXPIRED' and print the HTTP status code and response body.

	download_id is interpolated into the request URL as-is.
	"""
	status_url = f"{topcat_url}/admin/download/{download_id}/status"
	form = {
		"facilityName": facility_name,
		"sessionId": session_id,
		"value": "EXPIRED"
	}
	outcome = requests.put(status_url, data=form, verify=verifySsl)
	print(outcome.status_code, outcome.text)


def expire_all_pending_downloads():
	"""Expire every non-deleted download of this facility whose status is PREPARING or RESTORING.

	The matching downloads are fetched from "/admin/downloads" and each one
	is passed to _expire_download by its "id".
	"""
	# NOTE(review): unlike the other queryOffset values in this file, this one
	# does not start with "where" - confirm which form the endpoint expects.
	conditions = [
		"(download.status like 'PREPARING' or download.status like 'RESTORING')",
		"download.isDeleted = false",
		"download.facilityName = '" + facility_name + "'",
	]
	query = " and ".join(conditions)

	reply = requests.get(topcat_url + "/admin/downloads", params={
		"facilityName": facility_name,
		"sessionId": session_id,
		"queryOffset": query
	}, verify=verifySsl)
	for pending in json.loads(reply.text):
		_expire_download(download_id=pending["id"])

def _fetch_download_type_status(download_type):
	"""Return the parsed status of download_type, or None when the request was not OK.

	On a non-OK response the HTTP status code and body are printed.
	"""
	response = requests.get(topcat_url + "/user/downloadType/" + str(download_type) + "/status", params={
			"facilityName": facility_name,
			"sessionId": session_id
		}, verify=verifySsl)
	if response:
		return json.loads(response.text)
	print("Response for '" + str(download_type) + "' not OK: " + str(response.status_code) + ", " + response.text)
	return None


def _put_download_type_status(download_type, disabled, message):
	"""Send the new disabled flag and message for download_type and print the outcome."""
	response = requests.put(topcat_url + "/admin/downloadType/" + str(download_type) + "/status", data={
		"facilityName": facility_name,
		"sessionId": session_id,
		"disabled": disabled,
		"message": message
	}, verify=verifySsl)
	if response:
		print(("Disabled" if disabled else "Enabled") + " download type " + download_type)
	else:
		print("Request failed: " + str(response.status_code) + ", " + response.text)


def manage_download_types():
	"""Interactively enable or disable download types.

	The download types are the keys of settingsJson["accessMethods"]. Their
	current statuses are fetched first; a type whose status cannot be fetched
	is treated as enabled. The user then repeatedly picks a number to toggle
	that type (a message may be supplied when disabling). Entering 0 leaves
	the loop, and so does any invalid input.

	After each toggle the status is fetched again. If that refresh fails, the
	previously known status is kept, rather than storing an error body that
	would crash the next redraw of the list.
	"""
	download_types = list(settingsJson["accessMethods"])
	download_statuses = []
	for download_type in download_types:
		download_status = _fetch_download_type_status(download_type)
		if download_status is None:
			print("Treating this download type as enabled")
			download_status = {"disabled":False,"message":""}
		download_statuses.append(download_status)
	while True:
		print("Current download type statuses:")
		for i, download_type in enumerate(download_types):
			report = download_type
			if download_statuses[i]["disabled"]:
				# "or ''" guards against a null message, which cannot be concatenated
				report += " (disabled, '" + (download_statuses[i]["message"] or "") + "')"
			else:
				report += " (enabled)"
			print(str(i+1) + ": " + report)
		print()
		try:
			option_number = int(input("Choose a number to toggle that download type's status, or 0 to exit: "))
		except ValueError:
			# this will trigger the Invalid Input response below
			option_number = -1
		if option_number == 0:
			break
		option_number = option_number - 1
		if option_number not in range(len(download_types)):
			print("Invalid input")
			# We break here rather than continue, so if correct input is impossible at least we escape
			break
		download_type = download_types[option_number]
		download_status = download_statuses[option_number]
		if download_status["disabled"]:
			# Note: we don't reset the message - possible option to reuse it in future?
			_put_download_type_status(download_type, False, download_status["message"])
		else:
			if download_status["message"]:
				old_message = download_status["message"]
				print("Current message for the disabled download type is: " + old_message)
			else:
				old_message = ""
				print("Current message for the disabled download type is: system default")
			message = input("New message (blank = use current): ")
			if not message:
				message = old_message
			_put_download_type_status(download_type, True, message)
		# Update the local copy of the download status, but only from an OK response
		refreshed_status = _fetch_download_type_status(download_type)
		if refreshed_status is None:
			print("Keeping the previously known status for this download type")
		else:
			download_statuses[option_number] = refreshed_status


def queue_visit():
	"""Prompt for transport, notification email and visit id, POST them to "/user/queue/visit" and print the response text."""
	transport = input("Enter transport mechanism: ")
	email = input("Enter email to notify upon completion: ")
	visit_id = input("Enter visit id: ")
	form = {
		"facilityName": facility_name,
		"sessionId": session_id,
		"transport": transport,
		"email": email,
		"visitId": visit_id,
	}
	reply = requests.post(url=topcat_url + "/user/queue/visit", data=form, verify=verifySsl)
	print(reply.text)


def queue_files():
	"""Queue the file locations listed in a local text file, in batches of at most 10000.

	Prompts for the transport mechanism, the notification email and the path
	of a file holding one file location per line. Each batch is POSTed to
	"/user/queue/files" and the response text is printed.

	When the list is split into several batches, every batch - including the
	final partial one - is sent with its own
	"<facility_name>_files_part_<n>" fileName. A list that fits into a single
	partial batch is sent without a fileName, as before. Blank lines in the
	input file are skipped rather than submitted as empty locations.
	"""
	transport = input("Enter transport mechanism: ")
	email = input("Enter email to notify upon completion: ")
	local_file = input("Enter path to local file containing newline delimited file locations: ")
	batch_size = 10000
	part_number = 1
	files = []
	data = {
		"facilityName": facility_name,
		"sessionId": session_id,
		"transport": transport,
		"email": email,
	}
	url = topcat_url + "/user/queue/files"
	with open(local_file) as f:
		# Iterate the file lazily instead of loading every line up front.
		for line in f:
			location = line.strip()
			if not location:
				continue
			files.append(location)
			if len(files) >= batch_size:
				data["files"] = files
				data["fileName"] = f"{facility_name}_files_part_{part_number}"
				print(requests.post(url=url, data=data, verify=verifySsl).text)
				part_number += 1
				files = []

	if files:
		data["files"] = files
		if part_number > 1:
			# Earlier parts were sent: without this the last batch would reuse
			# the previous part's fileName.
			data["fileName"] = f"{facility_name}_files_part_{part_number}"
		print(requests.post(url=url, data=data, verify=verifySsl).text)


def get_all_queued_downloads():
	"""Return the raw "/admin/downloads" response text for all non-deleted downloads with status QUEUED."""
	query_offset = " AND ".join([
		"WHERE download.isDeleted != true",
		"download.status = org.icatproject.topcat.domain.DownloadStatus.QUEUED",
	])
	reply = requests.get(
		topcat_url + "/admin/downloads",
		params={
			"facilityName": facility_name,
			"sessionId": session_id,
			"queryOffset": query_offset,
		},
		verify=verifySsl,
	)
	return reply.text


def start_download(download_id):
	"""Set one download's status to 'PREPARING' and print the HTTP status code and response body."""
	status_url = f"{topcat_url}/admin/download/{download_id}/status"
	form = {
		"facilityName": facility_name,
		"sessionId": session_id,
		"value": "PREPARING"
	}
	outcome = requests.put(url=status_url, data=form, verify=verifySsl)
	print(outcome.status_code, outcome.text)


def show_all_queued_downloads():
	"""Print the raw response text listing all queued downloads."""
	queued_text = get_all_queued_downloads()
	print(queued_text)


def start_queued_download():
	"""Start each download whose id the user enters."""
	entered_ids = input_download_ids()
	for entered_id in entered_ids:
		start_download(download_id=entered_id)


def start_queued_downloads():
	"""Start every download currently returned by get_all_queued_downloads()."""
	queued = json.loads(get_all_queued_downloads())
	for entry in queued:
		start_download(str(entry["id"]))


def requeue_download():
	"""Set the status of each entered download id to 'QUEUED'.

	The HTTP status code and response body of every request are printed.
	"""
	for download_id in input_download_ids():
		status_url = f"{topcat_url}/admin/download/{download_id}/status"
		form = {
			"facilityName": facility_name,
			"sessionId": session_id,
			"value": "QUEUED"
		}
		outcome = requests.put(url=status_url, data=form, verify=verifySsl)
		print(outcome.status_code, outcome.text)


def expire_all_queued_downloads():
	"""Expire every download currently returned by get_all_queued_downloads()."""
	queued = json.loads(get_all_queued_downloads())
	for entry in queued:
		_expire_download(str(entry["id"]))


# Text of the interactive menu, printed before every prompt.
menu_lines = (
	"",
	"What do you want to do?",
	" * 1: Show download(s).",
	" * 2: Get a list of all the file locations for download(s).",
	" * 3: Prepare download(s).",
	" * 4: Set download(s) status to 'EXPIRED'.",
	" * 5: Expire all pending downloads.",
	" * 6: Enable or disable download types.",
	" * 7: Queue visit.",
	" * 8: Queue files.",
	" * 9: Show all queued downloads.",
	" * 10: Start queued download(s).",
	" * 11: Start all queued downloads.",
	" * 12: Re-queue expired download(s).",
	" * 13: Expire all queued downloads.",
	" * 14: Exit",
)

# Maps the option text typed by the user to the function that handles it ("14" exits and is handled in the loop).
menu_actions = {
	"1": show_download,
	"2": list_file_locations,
	"3": prepare_download,
	"4": expire_download,
	"5": expire_all_pending_downloads,
	"6": manage_download_types,
	"7": queue_visit,
	"8": queue_files,
	"9": show_all_queued_downloads,
	"10": start_queued_download,
	"11": start_queued_downloads,
	"12": requeue_download,
	"13": expire_all_queued_downloads,
}

while True:
	for menu_line in menu_lines:
		print(menu_line)

	option_number = input("Enter option number: ")

	if option_number == "14":
		break

	chosen_action = menu_actions.get(option_number)
	if chosen_action is None:
		print("")
		print("Unknown option")
	else:
		chosen_action()

