001/*
002 * Copyright (c) 2012-2021 Institut National des Sciences Appliquées de Lyon (INSA Lyon) and others
003 *
004 * This program and the accompanying materials are made available under the
005 * terms of the Eclipse Public License 2.0 which is available at
006 * http://www.eclipse.org/legal/epl-2.0.
007 *
008 * SPDX-License-Identifier: EPL-2.0
009 */
010
011package org.eclipse.golo.runtime;
012
import gololang.FunctionReference;

import java.lang.invoke.*;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.reflect.*;
import java.util.Arrays;
import java.util.Optional;

import static java.lang.invoke.MethodType.methodType;
import static java.lang.reflect.Modifier.isPrivate;
import static java.lang.reflect.Modifier.isStatic;
import static org.eclipse.golo.runtime.DecoratorsHelper.getDecoratedMethodHandle;
import static org.eclipse.golo.runtime.DecoratorsHelper.isMethodDecorated;
import static gololang.Messages.message;
import static gololang.Messages.info;
import static org.eclipse.golo.runtime.Extractors.checkDeprecation;
import static org.eclipse.golo.runtime.NamedArgumentsHelper.*;
029
030public final class FunctionCallSupport {
031  private static final boolean DEBUG = Boolean.getBoolean("golo.debug.function-resolution");
032
033
034  private static void debug(String message, Object... args) {
035    if (DEBUG || gololang.Runtime.debugMode()) {
036      info("Function resolution: " + String.format(message, args));
037    }
038  }
039
040  private FunctionCallSupport() {
041    throw new UnsupportedOperationException("Don't instantiate invokedynamic bootstrap class");
042  }
043
044  public static class FunctionCallSite extends MutableCallSite {
045
046    final Lookup callerLookup;
047    final String name;
048    final boolean constant;
049    final String[] argumentNames;
050
051    FunctionCallSite(Lookup callerLookup, String name, MethodType type, boolean constant, String... argumentNames) {
052      super(type);
053      this.callerLookup = callerLookup;
054      this.name = name;
055      this.constant = constant;
056      this.argumentNames = argumentNames;
057    }
058  }
059
060  private static final MethodHandle FALLBACK;
061  private static final MethodHandle SAM_FILTER;
062  private static final MethodHandle FUNCTIONAL_INTERFACE_FILTER;
063
064  static {
065    try {
066      Lookup lookup = MethodHandles.lookup();
067      FALLBACK = lookup.findStatic(
068          FunctionCallSupport.class,
069          "fallback",
070          methodType(Object.class, FunctionCallSite.class, Object[].class));
071      SAM_FILTER = lookup.findStatic(
072          FunctionCallSupport.class,
073          "samFilter",
074          methodType(Object.class, Class.class, Object.class));
075      FUNCTIONAL_INTERFACE_FILTER = lookup.findStatic(
076          FunctionCallSupport.class,
077          "functionalInterfaceFilter",
078          methodType(Object.class, Lookup.class, Class.class, Object.class));
079    } catch (NoSuchMethodException | IllegalAccessException e) {
080      throw new Error("Could not bootstrap the required method handles", e);
081    }
082  }
083
084  public static Object samFilter(Class<?> type, Object value) {
085    if (value instanceof FunctionReference) {
086      return MethodHandleProxies.asInterfaceInstance(type, ((FunctionReference) value).handle());
087    }
088    return value;
089  }
090
091  public static Object functionalInterfaceFilter(Lookup caller, Class<?> type, Object value) throws Throwable {
092    if (value instanceof FunctionReference) {
093      return asFunctionalInterface(caller, type, ((FunctionReference) value).handle());
094    }
095    return value;
096  }
097
098  public static Object asFunctionalInterface(Lookup caller, Class<?> type, MethodHandle handle) throws Throwable {
099    for (Method method : type.getMethods()) {
100      if (!method.isDefault() && !isStatic(method.getModifiers())) {
101        MethodType lambdaType = methodType(method.getReturnType(), method.getParameterTypes());
102        CallSite callSite = LambdaMetafactory.metafactory(
103            caller,
104            method.getName(),
105            methodType(type),
106            lambdaType,
107            handle,
108            lambdaType);
109        return callSite.dynamicInvoker().invoke();
110      }
111    }
112    throw new RuntimeException(message("handle_conversion_failed", handle, type));
113  }
114
115  public static CallSite bootstrap(Lookup caller, String name, MethodType type, Object... bsmArgs) throws IllegalAccessException, ClassNotFoundException {
116    boolean constant = ((int) bsmArgs[0]) == 1;
117    String[] argumentNames = new String[bsmArgs.length - 1];
118    for (int i = 0; i < bsmArgs.length - 1; i++) {
119      argumentNames[i] = (String) bsmArgs[i + 1];
120    }
121    FunctionCallSite callSite = new FunctionCallSite(
122        caller,
123        name.replaceAll("#", "\\."),
124        type,
125        constant,
126        argumentNames);
127    MethodHandle fallbackHandle = FALLBACK
128        .bindTo(callSite)
129        .asCollector(Object[].class, type.parameterCount())
130        .asType(type);
131    callSite.setTarget(fallbackHandle);
132    return callSite;
133  }
134
135  public static Object fallback(FunctionCallSite callSite, Object[] args) throws Throwable {
136    String functionName = callSite.name;
137    MethodType type = callSite.type();
138    Lookup caller = callSite.callerLookup;
139    Class<?> callerClass = caller.lookupClass();
140    String[] argumentNames = callSite.argumentNames;
141
142    MethodHandle handle = null;
143    AccessibleObject result = null;
144    if (!functionName.contains(".")) {
145      result = findStaticMethodOrField(callerClass, callerClass, functionName, args);
146    }
147    if (result == null) {
148      result = findClassWithStaticMethodOrField(callerClass, functionName, args);
149    }
150    if (result == null) {
151      result = findClassWithStaticMethodOrFieldFromImports(callerClass, functionName, args);
152    }
153    if (result == null) {
154      result = findClassWithConstructor(callerClass, functionName, args);
155    }
156    if (result == null) {
157      result = findClassWithConstructorFromImports(callerClass, functionName, args);
158    }
159    if (result == null) {
160      throw new NoSuchMethodError(functionName + type.toMethodDescriptorString());
161    }
162
163    Class<?>[] types = null;
164    if (result instanceof Method) {
165      Method method = (Method) result;
166      checkLocalFunctionCallFromSameModuleAugmentation(method, callerClass.getName());
167      if (isMethodDecorated(method)) {
168        handle = getDecoratedMethodHandle(caller, method, type.parameterCount());
169      } else {
170        types = method.getParameterTypes();
171        handle = caller.unreflect(method);
172        if (method.isAnnotationPresent(WithCaller.class)) {
173          handle = handle.bindTo(callerClass);
174        }
175        //TODO: improve varargs support on named arguments. Matching the last param type + according argument
176        if (isVarargsWithNames(method, types, args, argumentNames)) {
177          handle = handle.asFixedArity().asType(type);
178        } else {
179          handle = handle.asType(type);
180        }
181      }
182      handle = reorderArguments(method, handle, argumentNames);
183    } else if (result instanceof Constructor) {
184      Constructor<?> constructor = (Constructor<?>) result;
185      types = constructor.getParameterTypes();
186      if (constructor.isVarArgs() && TypeMatching.isLastArgumentAnArray(types.length, args)) {
187        handle = caller.unreflectConstructor(constructor).asFixedArity().asType(type);
188      } else {
189        handle = caller.unreflectConstructor(constructor).asType(type);
190      }
191    } else {
192      Field field = (Field) result;
193      handle = caller.unreflectGetter(field).asType(type);
194    }
195    handle = insertSAMFilter(handle, callSite.callerLookup, types, 0);
196
197    if (callSite.constant) {
198      Object constantValue = handle.invokeWithArguments(args);
199      MethodHandle constant;
200      if (constantValue == null) {
201        constant = MethodHandles.constant(Object.class, null);
202      } else {
203        constant = MethodHandles.constant(constantValue.getClass(), constantValue);
204      }
205      constant = MethodHandles.dropArguments(constant, 0, type.parameterArray());
206      callSite.setTarget(constant.asType(type));
207      return constantValue;
208    } else {
209      callSite.setTarget(handle);
210      return handle.invokeWithArguments(args);
211    }
212  }
213
214  private static boolean isVarargsWithNames(Method method, Class<?>[] types, Object[] args, String[] argumentNames) {
215    return method.isVarArgs()
216      && (
217          TypeMatching.isLastArgumentAnArray(types.length, args)
218          || argumentNames.length > 0);
219  }
220
221  public static MethodHandle reorderArguments(Method method, MethodHandle handle, String[] argumentNames) {
222    return NamedArgumentsHelper.reorderArguments(
223        method.getName(),
224        getParameterNames(method),
225        handle,
226        argumentNames, 0, 0);
227  }
228
229  public static MethodHandle insertSAMFilter(MethodHandle handle, Lookup caller, Class<?>[] types, int startIndex) {
230    if (types != null) {
231      for (int i = 0; i < types.length; i++) {
232        if (TypeMatching.isSAM(types[i])) {
233          handle = MethodHandles.filterArguments(handle, startIndex + i, SAM_FILTER.bindTo(types[i]));
234        } else if (TypeMatching.isFunctionalInterface(types[i])) {
235          handle = MethodHandles.filterArguments(
236              handle,
237              startIndex + i,
238              FUNCTIONAL_INTERFACE_FILTER.bindTo(caller).bindTo(types[i]));
239        }
240      }
241    }
242    return handle;
243  }
244
245  private static void checkLocalFunctionCallFromSameModuleAugmentation(Method method, String callerClassName) {
246    if (isPrivate(method.getModifiers()) && callerClassName.contains("$")) {
247      String prefix = callerClassName.substring(0, callerClassName.indexOf("$"));
248      if (method.getDeclaringClass().getName().equals(prefix)) {
249        method.setAccessible(true);
250      }
251    }
252  }
253
254  private static AccessibleObject findClassWithConstructorFromImports(Class<?> callerClass, String classname, Object[] args) {
255    String[] imports = Module.imports(callerClass);
256    for (String imported : imports) {
257      AccessibleObject result = findClassWithConstructor(
258          callerClass,
259          mergeImportAndCall(imported, classname),
260          args);
261      if (result != null) {
262        return result;
263      }
264    }
265    return null;
266  }
267
268  private static AccessibleObject findClassWithConstructor(Class<?> callerClass, String classname, Object[] args) {
269    debug("looking for constructor for `%s`", classname);
270    try {
271      Class<?> targetClass = Class.forName(classname, true, callerClass.getClassLoader());
272      for (Constructor<?> constructor : targetClass.getConstructors()) {
273        if (TypeMatching.argumentsMatch(constructor, args)) {
274          debug("constructor found");
275          return checkDeprecation(callerClass, constructor);
276        }
277      }
278    } catch (ClassNotFoundException ignored) {
279      // ignored to try the next strategy
280    }
281    return null;
282  }
283
284  static String mergeImportAndCall(String importName, String functionName) {
285    if (importName == null || importName.isEmpty()) {
286      return functionName;
287    }
288    String[] importParts = importName.split("\\.");
289    String[] functionParts = functionName.split("\\.");
290    StringBuilder merged = new StringBuilder();
291    int fidx = 0;
292    for (String imp : importParts) {
293      if (imp.equals(functionParts[fidx])) {
294        fidx++;
295      } else if (fidx > 0) {
296        return importName + '.' + functionName;
297      }
298      if (merged.length() != 0) {
299        merged.append('.');
300      }
301      merged.append(imp);
302    }
303    while (fidx < functionParts.length - 1) {
304      merged.append('.').append(functionParts[fidx]);
305      fidx++;
306    }
307    if (fidx == functionParts.length - 1) {
308      merged.append('.').append(functionParts[fidx]);
309    }
310    return merged.toString();
311  }
312
313  private static AccessibleObject findClassWithStaticMethodOrFieldFromImports(Class<?> callerClass, String functionName, Object[] args) {
314    AccessibleObject result = null;
315    if (functionName.contains(".")) {
316      result = findClassWithStaticMethodOrField(
317          callerClass,
318          mergeImportAndCall(callerClass.getCanonicalName(), functionName),
319          args);
320      if (result != null) {
321        return result;
322      }
323    }
324    String[] imports = Module.imports(callerClass);
325    for (String importedClassName : imports) {
326      result = findClassWithStaticMethodOrField(
327          callerClass,
328          mergeImportAndCall(importedClassName, functionName),
329          args);
330      if (result != null) {
331        return result;
332      }
333    }
334    return null;
335  }
336
337  private static AccessibleObject findClassWithStaticMethodOrField(Class<?> callerClass, String functionName, Object[] args) {
338    int methodClassSeparatorIndex = functionName.lastIndexOf(".");
339    if (methodClassSeparatorIndex >= 0) {
340      String className = functionName.substring(0, methodClassSeparatorIndex);
341      String methodName = functionName.substring(methodClassSeparatorIndex + 1);
342      debug("looking for function `%s` in named `%s`", methodName, className);
343      try {
344        Class<?> targetClass = Class.forName(className, true, callerClass.getClassLoader());
345        return findStaticMethodOrField(callerClass, targetClass, methodName, args);
346      } catch (ClassNotFoundException ignored) {
347        // ignored to try the next strategy
348        Warnings.unavailableClass(className, callerClass.getName());
349      }
350    }
351    return null;
352  }
353
354  private static AccessibleObject findStaticMethodOrField(Class<?> caller, Class<?> klass, String name, Object[] arguments) {
355    debug("looking for function `%s` in loaded class `%s`", name, klass.getCanonicalName());
356    Optional<Method> meth = Extractors.getMethods(klass)
357      .filter(m -> methodMatches(caller, name, arguments, m, m.isVarArgs()))
358      .map(m -> checkDeprecation(caller, m))
359      .findFirst();
360    if (meth.isPresent()) {
361      debug("method found");
362      return meth.get();
363    }
364    if (arguments.length == 0) {
365      Optional<Field> f = Extractors.getFields(klass)
366        .filter(o -> fieldMatches(name, o))
367        .map(o -> checkDeprecation(caller, o))
368        .findFirst();
369      return f.orElse(null);
370    }
371    return null;
372  }
373
374  private static boolean methodMatches(Class<?> caller, String name, Object[] arguments, Method method, boolean varargs) {
375    return methodMatches(caller, name, arguments, method, varargs, true);
376  }
377
378  private static boolean methodMatches(Class<?> caller, String name, Object[] arguments, Method method, boolean varargs, boolean tryCaller) {
379    if (!method.getName().equals(name) || !isStatic(method.getModifiers())) { return false; }
380    if (isMethodDecorated(method)) { return true; }
381    if (TypeMatching.argumentsMatch(method, arguments, varargs)) { return true; }
382    if (method.isAnnotationPresent(WithCaller.class) && tryCaller) {
383      Object[] argsWithCaller = new Object[arguments.length + 1];
384      argsWithCaller[0] = caller;
385      System.arraycopy(arguments, 0, argsWithCaller, 1, arguments.length);
386      return methodMatches(caller, name, argsWithCaller, method, varargs, false);
387    }
388    return false;
389  }
390
391  private static boolean fieldMatches(String name, Field field) {
392    return field.getName().equals(name) && isStatic(field.getModifiers());
393  }
394}