import copy
from debug import myprint
from apm_helpers.messages import Messages as msg
class LoopSolution:
 """Per-row record of the offload evaluation for one row of the loop tree.

 Every attribute starts out as ``None``; ``find_most_profitable_heads_impl``
 assigns the ones relevant to each row it evaluates.
 """
 def __init__(self):
  # Create every field, in this order, initialised to None.
  for field_name in (
   'is_aggregated',
   'loopnest',
   'is_too_small',
   'is_marked_up',
   'base_time',
   'fractional_time',
   'estimation',
   'estimation_idx',
   'is_profitable',
   'is_less_profitable_than_parent',
   'is_less_profitable_than_children',
   'speed_up',
   'gain',
   'is_potential_offload_head',
   'is_offloaded',
   'no_constraint_estimation',
   'max_speedup',
  ):
   setattr(self, field_name, None)
 def copy(self):
  """Return a shallow copy: attribute values are shared with the original."""
  duplicate = copy.copy(self)
  return duplicate
def find_most_profitable_heads_impl(top_rows,loopnest_indices,estimations,row2estimation_idx,host_estimations,min_required_speed_up,max_speed_up_limit,MDT,loop_filter_threshold,unroll_functions,total_time_host,whether_to_model_children,whether_check_profitability,):
 """Choose which rows of the call/loop tree to offload so that the total gain is maximised.

 The trees rooted at ``top_rows`` are flattened breadth-first into ``rows``
 (a parent always precedes its children), then walked in reverse so that each
 row is evaluated after all of its modelled descendants. For every row, the
 best non-relaxed estimation (highest gain) is compared with the summed gain
 already propagated up from its children. The better option is kept and
 propagated to the parent.

 Args:
  top_rows: roots of the trees to analyse.
  loopnest_indices: mapping from row to its loop-nest index. A row found here
   that does not inherit a loop nest from an ancestor becomes a loop-nest head.
  estimations: not used by this function.
  row2estimation_idx: mapping from row to ``(estimation, gain, estimation_idx)``
   tuples. Estimations with a truthy ``'relaxed'`` entry are only used for
   ``no_constraint_estimation`` / ``max_speedup``.
  host_estimations: mapping from row to an entry with ``'total_base_time'``
   and ``'total_time'``.
  min_required_speed_up: a row's own speed-up must exceed this to count as
   profitable (unless ``whether_check_profitability`` says not to check).
  max_speed_up_limit, MDT, unroll_functions: not used by this function.
  loop_filter_threshold: rows whose base time is below this are marked too
   small and are not evaluated for offload.
  total_time_host: denominator of ``LoopSolution.fractional_time``.
  whether_to_model_children: predicate; children of a row that belongs to a
   loop nest are only explored when it returns true.
  whether_check_profitability: predicate; when it returns false for a row,
   the speed-up threshold is not enforced for that row.

 Returns:
  ``(loopnests, offload_heads, non_offloaded_heads, solution)``:
  ``loopnests`` maps each head's ``key_column`` to ``(head_row, offloaded_rows)``;
  ``offload_heads`` lists every offloaded row; ``non_offloaded_heads`` is what
  ``find_missed_branches`` (with ``IsLoopChecker``) reports as not covered by
  any offload; ``solution`` maps each modelled row to its ``LoopSolution``.
 """
 # Indices into ``rows`` of the loop-nest heads.
 loopnest_index_list=[]
 rows=[
  {
   'idx':-1,
   'row':top_row,
   'level':1,
   'parent_index':None,
   'too_small':False,
   'base_time':0.0,
   'gain':None,
   'offloads':{},
   'per_loop_solution':{},
   'call_stack':(top_row,),
   'loopnest':None,
  }
  for top_row in top_rows
 ]
 # Breadth-first expansion: ``rows`` grows while it is being scanned, hence
 # the explicit index.
 idx=0
 while idx<len(rows):
  curr=rows[idx]
  curr_row=curr['row']
  curr['idx']=idx
  if curr_row in host_estimations:
   curr['base_time']=host_estimations[curr_row]['total_base_time']
  else:
   curr['base_time']=None
  if curr['loopnest'] is None:
   # The first row on a path that has a loop-nest index becomes a head;
   # its descendants inherit the same (index, head_row) tuple.
   try:
    curr['loopnest']=loopnest_indices[curr_row],curr_row
    loopnest_index_list.append(idx)
    myprint("Potential offload top level: {}".format(curr_row['function_call_sites_and_loops']))
   except KeyError:
    pass
  if curr['base_time'] is not None and curr['base_time']<loop_filter_threshold:
   myprint('{} (id: {}) is too small, skipping offload evaluation'.format(curr_row['function_call_sites_and_loops'],curr_row['key_column']))
   curr['too_small']=True
  if curr['loopnest'] is None or whether_to_model_children(curr_row):
   rows+=[
    {
     'idx':None,
     'row':x,
     'level':curr['level']+1,
     'parent_index':idx,
     'too_small':False,
     'gain':None,
     'offloads':{},
     'per_loop_solution':{},
     'call_stack':curr['call_stack']+(x,),
     'loopnest':curr['loopnest'],
    }
    for x in curr_row.children
   ]
  idx+=1
 # Bottom-up pass: ``rows`` is in BFS order, so reversing it visits every
 # child before its parent.
 for curr in reversed(rows):
  curr_row=curr['row']
  if curr['parent_index'] is not None:
   parent=rows[curr['parent_index']]
  else:
   parent=None
  # Sum of the gains propagated up from this row's children (None if none).
  gain_if_children_offloaded=curr['gain']
  myprint(msg.DEBUG_CHECKING_INTERNAL_OFFLOAD.format(curr_row['function_call_sites_and_loops'],curr_row['loop_function_id'],curr_row['key_column']))
  solution=LoopSolution()
  solution.is_too_small=curr['too_small']
  solution.is_marked_up=curr_row.is_marked_up
  solution.base_time=curr['base_time']
  # NOTE(review): raises KeyError for rows missing from host_estimations,
  # although the first pass tolerates them (base_time None) - confirm every
  # modelled row is expected to be present.
  solution.fractional_time=host_estimations[curr_row]['total_time']/total_time_host
  solution.is_aggregated=False
  est,gain=None,None
  curr_row_estimations=row2estimation_idx.get(curr_row)
  if not curr['too_small'] and curr_row_estimations:
   # Best non-relaxed estimation by gain; relaxed ones are handled below.
   est,gain,est_idx=max(
    [x for x in curr_row_estimations if not x[0].get('relaxed') and x[1] is not None],
    key=lambda x:x[1],
    default=(None,None,None),
   )
   solution.estimation=est
   solution.estimation_idx=est_idx
  if gain is None:
   solution.is_offloaded=False
  else:
   solution.loopnest=curr['loopnest']
   if est['time']>0.0:
    per_loop_speed_up=curr['base_time']/est['time']
    solution.speed_up=per_loop_speed_up
    solution.gain=gain
   else:
    # Non-positive estimated time: use a neutral speed-up of 1.0.
    per_loop_speed_up=1.0
   relaxed_est,_,relaxed_idx=max(
    [x for x in curr_row_estimations if x[0].get('relaxed') and x[1] is not None],
    key=lambda x:x[1],
    default=(None,None,None),
   )
   if relaxed_est is not None:
    solution.no_constraint_estimation=relaxed_est,relaxed_idx
    if relaxed_est['time']>0.0:
     solution.max_speedup=curr['base_time']/relaxed_est['time']
   solution.is_profitable=per_loop_speed_up>min_required_speed_up or not whether_check_profitability(curr_row)
   myprint((msg.DEBUG_OFFLOAD_IS_PROFITABLE if solution.is_profitable else msg.DEBUG_OFFLOAD_IS_NOT_PROFITABLE_OR_NOT_POSSIBLE).format(curr_row['function_call_sites_and_loops']))
   if solution.is_profitable and (gain_if_children_offloaded is None or gain>gain_if_children_offloaded):
    # Offloading this row beats its children's offloads: replace them.
    curr['offloads']={curr['idx']:curr_row}
    solution.is_offloaded=True
    solution.is_less_profitable_than_children=False
    if gain_if_children_offloaded is not None:
     myprint(msg.DEBUG_TOP_LEVEL_OFFLOAD_INSTEAD_OF_CHILDREN_OFFLOADS)
    # per_loop_solution is keyed by row objects and, at this point, holds
    # only descendants; the identity check guards against including self.
    for k,v in curr['per_loop_solution'].items():
     if k is not curr_row:
      v.is_less_profitable_than_parent=True
      v.is_offloaded=False
   else:
    if solution.is_profitable and gain_if_children_offloaded is not None:
     solution.is_less_profitable_than_children=True
     myprint(msg.DEBUG_CHILDREN_OFFLOADS_INSTEAD_OF_TOP_LEVEL_OFFLOAD)
    else:
     solution.is_less_profitable_than_children=False
    solution.is_offloaded=False
    # Keep the children's offloads; their gain is what this row reports up.
    gain=gain_if_children_offloaded
   solution.is_less_profitable_than_parent=False
  if parent is not None:
   # Report this row's own gain if it has one, otherwise its children's.
   # Test against None (not truthiness) so a gain of exactly 0.0 still
   # carries the chosen offloads up to the parent.
   propagated_gain=gain if gain is not None else gain_if_children_offloaded
   if propagated_gain is not None:
    parent['gain']=propagated_gain+(parent['gain'] or 0)
    parent['offloads'].update(curr['offloads'])
  curr['per_loop_solution'][curr_row]=solution
  if parent is not None:
   parent['per_loop_solution'].update(curr['per_loop_solution'])
 # Collect the results from the loop-nest heads.
 loopnests={}
 offload_heads=[]
 solution={}
 for idx in loopnest_index_list:
  top_data=rows[idx]
  head_row=top_data['row']
  curr_offload_heads=list(top_data['offloads'].values())
  loopnests[head_row['key_column']]=(head_row,curr_offload_heads)
  offload_heads+=curr_offload_heads
  solution.update(top_data['per_loop_solution'])
 non_offloaded_heads=find_missed_branches(offload_heads,[x[0] for x in loopnests.values()],row_fit=IsLoopChecker())
 for x in non_offloaded_heads:
  # NOTE(review): raises KeyError if x was never modelled (possibly a row
  # below one for which whether_to_model_children returned false) - confirm
  # that cannot happen.
  solution[x].is_potential_offload_head=True
 return loopnests,offload_heads,non_offloaded_heads,solution
class OffloadCandidateChecker:
 """Predicate callable that reports whether a row is an offload candidate.

 ``solution`` is looked up as ``solution[row]['estimation']``.
 NOTE(review): find_most_profitable_heads_impl produces attribute-style
 LoopSolution objects, which are not subscriptable; passing such a dict here
 would raise TypeError - confirm the expected input type.
 """
 def __init__(self,solution):
  self.solution=solution
 def __call__(self,row):
  """Return True for rows without an entry (or without an 'estimation').

  Otherwise return False for a falsy estimation, and the estimation's
  ``'is_offload_candidate'`` value (None if absent) for anything else.
  """
  try:
   estimation=self.solution[row]['estimation']
  except KeyError:
   return True
  return estimation.get('is_offload_candidate') if estimation else False
class IsLoopChecker:
 """Predicate callable: calling an instance on a row returns ``row.is_loop``."""
 def __call__(self,row):
  loop_flag=row.is_loop
  return loop_flag
def find_missed_branches(offl_heads,top_rows,row_fit=None):
 """Return the branches of the trees under ``top_rows`` that no offload head covers.

 Rows are identified by ``row['key_column']`` and navigated through their
 ``parent`` and ``children`` attributes. The uncovered set starts as
 ``top_rows``.

 For each offload head, the head itself and then its ancestors are visited
 until one is found in the uncovered set. That row is removed and replaced by
 the children of every visited row, except:
  - rows on the visited path itself;
  - the offload head's own children, which it covers.
 An offload head with no visited row in the uncovered set leaves it unchanged.

 Args:
  offl_heads: rows chosen for offload.
  top_rows: roots of the trees to examine.
  row_fit: optional predicate. When given, each uncovered row is searched
   depth-first, and the first rows on each path for which ``row_fit`` is
   true are returned instead; their descendants are not examined.

 Returns:
  A list of rows, in both modes.
 """
 noffl_heads={x['key_column']:x for x in top_rows}
 for row in offl_heads:
  offl_key=row['key_column']
  curr_row=row
  # Rows visited on the way up, keyed by key_column.
  stack={}
  while curr_row:
   curr_key=curr_row['key_column']
   stack[curr_key]=curr_row
   if curr_key in noffl_heads:
    del noffl_heads[curr_key]
    for k,v in stack.items():
     if k==offl_key:
      # The offload head covers its whole subtree.
      continue
     noffl_heads.update({x['key_column']:x for x in v.children if x['key_column'] not in stack})
    break
   curr_row=curr_row.parent
 if row_fit is None:
  # Return a list (not a live dict view) to match the row_fit branch.
  return list(noffl_heads.values())
 res=[]
 for x in noffl_heads.values():
  pending=[x]
  while pending:
   curr_row=pending.pop()
   if row_fit(curr_row):
    res.append(curr_row)
   else:
    pending+=curr_row.children
 return res
def estimate_profitability(objective_fn,rows,baseline_estimations,estimations):
 """Yield ``(weight_delta, speed_up)`` for each target estimation.

 The baseline is aggregated over ``rows``:
  - baseline time is the sum of ``baseline_estimations[row]['total_base_time']``;
  - baseline weight is the sum of ``objective_fn(baseline_estimations[row])``.
 For each ``target_est`` in ``estimations`` the generator yields:
  - ``objective_fn(target_est) - baseline_weight``;
  - ``baseline_time / target_est['time']``.

 ``rows`` may be any iterable, including a one-shot iterator. It is
 materialised when the generator is first advanced.

 Raises:
  KeyError: a row is missing from ``baseline_estimations``.
  ZeroDivisionError: a target estimation has ``'time'`` equal to zero.
 """
 # The baseline is aggregated twice; materialise rows so that a generator
 # is not exhausted by the first sum (which made the weight silently 0).
 baseline=[baseline_estimations[x] for x in list(rows)]
 baseline_time=sum(b['total_base_time'] for b in baseline)
 baseline_weight=sum(objective_fn(b) for b in baseline)
 for target_est in estimations:
  target_weight=objective_fn(target_est)
  yield(target_weight-baseline_weight,baseline_time/target_est['time'],)
