detector = hub.Module( “https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1" )
m = hub.Module(“/tmp/text-embedding”) embeddings = m(sentences)
# compute statistics for a CSV file
train_stats = tfdv.generate_statistics_from_csv(TRAIN_DATA)
# compute statistics for TF Record files
train_stats = tfdv.generate_statistics_from_tfrecord(TRAIN_DATA)
# Infer schema based on statistics
schema = tfdv.infer_schema(train_stats)
# Display schema inline in table format
tfdv.display_schema(schema)
# Compute statistics over a new set of data
new_stats = tfdv.generate_statistics_from_csv(NEW_DATA)
# Compare how new data conforms to the schema
anomalies = tfdv.validate_statistics(new_stats, schema)
# Display anomalies inline
tfdv.display_anomalies(anomalies)
feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
schema = dataset_schema.from_feature_spec(feature_spec)
Na Zalando Research, assim como na maioria dos departamentos de pesquisa de IA, sabemos da importância da agilidade nos testes e na criação de protótipos. Com os conjuntos de dados ficando cada vez maiores, é bom saber como treinar modelos de aprendizado profundo com velocidade e eficiência usando os recursos compartilhados disponíveis.
A API Estimators do TensorFlow é útil para treinar modelos em ambientes distribuídos com várias GPUs. Aqui, apresentaremos esse fluxo de trabalho treinando um estimador personalizado criado com tf.keras para o pequeno conjunto de dados Fashion-MNIST. Depois mostraremos um caso de uso mais prático.
Nota: a equipe do TensorFlow também está trabalhando em um recurso novo e incrível (no momento da criação deste artigo, ainda está na fase de master) que permite treinar um modelo tf.keras sem precisar transformá-lo em um estimador, usando apenas algumas poucas linhas de código a mais. Esse fluxo de trabalho também é ótimo. Abaixo, nos limitaremos à Estimators API. Escolha o que for melhor para você.
TL;DR: na essência, o que queremos lembrar é que um tf.keras.Model pode ser treinado com a API tf.estimator convertendo-o em um objeto tf.estimator.Estimator pelo método tf.keras.estimator.model_to_estimator. Depois da conversão, podemos aplicar as ferramentas que o Estimators oferece para treinar em diferentes configurações de hardware.
Você pode baixar o código desta postagem neste bloco de anotações e executá-lo por conta própria.
import osimport time
#!pip install -q -U tensorflow-gpuimport tensorflow as tf
import numpy as np
Usaremos o conjunto de dados Fashion-MNIST, uma substituição direta do MNIST, que contém milhares de imagens em escala de cinza dos artigos de moda da Zalando. Obter os dados de treinamento e teste é bem simples:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
Queremos converter os valores de pixels dessas imagens de um número entre 0 e 255 em um número entre 0 e 1 e converter o conjunto de dados para o formato [B,H,W,C], onde "B" é o número de imagens do lote, "H" e "W" são a altura e o comprimento, e "C" é o número de canais (1 para escala de cinza) do nosso conjunto:
TRAINING_SIZE = len(train_images)TEST_SIZE = len(test_images)
train_images = np.asarray(train_images, dtype=np.float32) / 255# Convert the train images and add channelstrain_images = train_images.reshape((TRAINING_SIZE, 28, 28, 1))
test_images = np.asarray(test_images, dtype=np.float32) / 255# Convert the test images and add channelstest_images = test_images.reshape((TEST_SIZE, 28, 28, 1))
Depois, precisamos converter os rótulos de um ID de número inteiro (por exemplo, 2 ou suéter) para uma codificação tipo one-hot (por exemplo, 0,0,1,0,0,0,0,0,0,0). Para fazer isso, usaremos a função tf.keras.utils.to_categorical:
# How many categories we are predicting from (0-9)LABEL_DIMENSIONS = 10
train_labels = tf.keras.utils.to_categorical(train_labels, LABEL_DIMENSIONS)
test_labels = tf.keras.utils.to_categorical(test_labels, LABEL_DIMENSIONS)
# Cast the labels to floats, needed latertrain_labels = train_labels.astype(np.float32)test_labels = test_labels.astype(np.float32)
Criaremos nossa rede neural usando a API Keras Functional. O Keras é uma API de alto nível para criar e treinar modelos de aprendizado profundo que é intuitiva, modular e fácil de ampliar. O tf.keras é a implementação do TensorFlow para essa API e oferece suporte a recursos como execução antecipada, canais tf.data e estimadores.
Para a arquitetura, usaremos os ConvNets. Em um nível bem resumido, os ConvNets são pilhas de camadas convolucionais (Conv2D) e camadas de agrupamento, ou pooling (MaxPooling2D). Mas o mais importante é que, para cada exemplo de treinamento, eles recebem tensores 3D de formato (altura, largura e canais). No caso de imagens em escala de cinza, começam com channels=1 e retornam um tensor 3D.
Portanto, depois da parte do ConvNet, precisaremos nivelar (Flatten) o tensor e adicionar camadas de densidade (Dense) em que a última retorna um vetor de tamanho LABEL_DIMENSIONS com a ativação de tf.nn.softmax:
inputs = tf.keras.Input(shape=(28,28,1)) # Returns a placeholder
x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu)(inputs)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation=tf.nn.relu)(x)predictions = tf.keras.layers.Dense(LABEL_DIMENSIONS, activation=tf.nn.softmax)(x)
Agora, podemos definir o nosso modelo, selecionar o otimizador (escolhemos um do TensorFlow, e não do tf.keras.optimizers) e compilá-lo:
model = tf.keras.Model(inputs=inputs, outputs=predictions)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
Para criar um estimador no modelo Keras compilado, chamamos o método model_to_estimator. Veja que o estado inicial do modelo Keras é preservado no estimador criado.
Ok, mas o que os estimadores têm de tão bom assim? Podemos começar com o seguinte:
E como fazemos para treinar nosso modelo tf.keras simples para usar várias GPUs? Podemos usar o paradigma tf.contrib.distribute.MirroredStrategy, que faz replicação no gráfico com treinamento síncrono. Veja esta conversa no treinamento de Distributed TensorFlow para saber mais sobre essa estratégia.
Essencialmente, cada GPU de worker tem uma cópia da rede e recebe um subconjunto dos dados com os quais calcula os gradientes locais e, em seguida, espera que todos os workers terminem de maneira síncrona. Depois, os workers comunicam seus gradientes locais entre si com uma operação "Ring All-reduce" que, normalmente, é otimizada para reduzir a largura de banda da rede e aumentar a produtividade. Quando todos os gradientes chegarem, cada worker calculará a média deles e atualizará o seu parâmetro. A próxima etapa começará em seguida. Isso é ideal em casos com várias GPUs em um único nó conectado por alguma interconexão de alta velocidade.
Para usar essa estratégia, primeiro criamos um estimador no modelo tf.keras compilado e atribuímos a ele a configuração MirroredStrategy pelo RunConfig. Por padrão, essa configuração usará todas as GPUs, mas você também pode definir uma opção num_gpus para usar um número específico de GPUs:
NUM_GPUS = 2
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.keras.estimator.model_to_estimator(model, config=config)
Para canalizar os dados para os estimadores, precisamos configurar uma função de importação de dados que retorne um conjunto tf.data de lotes (imagens,rótulos) dos nossos dados. A função abaixo recebe matrizes "numpy" e retorna o conjunto dados por um processo de ETL.
Veja que, no fim, também chamamos o método "prefetch", que carrega os dados em buffer para as GPUs enquanto elas estão treinando, preparando o próximo lote e esperando pelas GPUs, em vez de fazer as GPUs esperarem os dados a cada iteração. Pode ser que a GPU ainda não seja totalmente utilizada, e, para melhorar isso, podemos usar versões combinadas das operações de transformação, como "shuffle_and_repeat", em vez de duas operações distintas, mas eu considerei apenas o caso mais simples.
def input_fn(images, labels, epochs, batch_size): # Convert the inputs to a Dataset. (E) ds = tf.data.Dataset.from_tensor_slices((images, labels))
# Shuffle, repeat, and batch the examples. (T) SHUFFLE_SIZE = 5000 ds = ds.shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size) ds = ds.prefetch(2)
# Return the dataset. (L) return ds
Primeiro, definiremos uma classe SessionRunHook para registrar os momentos de cada iteração do método do gradiente descendente estocástico:
class TimeHistory(tf.train.SessionRunHook): def begin(self): self.times = []
def before_run(self, run_context): self.iter_time_start = time.time()
def after_run(self, run_context, run_values): self.times.append(time.time() - self.iter_time_start)
Agora, chegamos na parte boa! Podemos chamar a função "train" do nosso estimador fornecendo a "input_fn" que definimos (com o tamanho do lote e o número de eras para as quais queremos treinar) e uma instância de TimeHistory por meio do argumento "hooks":
time_hist = TimeHistory()
BATCH_SIZE = 512EPOCHS = 5
estimator.train(lambda:input_fn(train_images, train_labels, epochs=EPOCHS, batch_size=BATCH_SIZE), hooks=[time_hist])
Graças ao nosso hook de temporização, agora podemos usá-lo para calcular o tempo total do treinamento, além da média de imagens que treinamos por segundo (a produtividade média):
total_time = sum(time_hist.times)print(f"total time with {NUM_GPUS} GPU(s): {total_time} seconds")
avg_time_per_batch = np.mean(time_hist.times)print(f"{BATCH_SIZE*NUM_GPUS/avg_time_per_batch} images/second with {NUM_GPUS} GPU(s)")
Para checar o desempenho do nosso modelo, chamamos o método "evaluate" no estimador:
estimator.evaluate(lambda:input_fn(test_images, test_labels, epochs=1, batch_size=BATCH_SIZE))
Para testar o desempenho de escalonamento em conjuntos de dados maiores, usamos as imagens de OCT retinais, um dos muitos conjuntos de dados grandes do Kaggle. Esse conjunto tem imagens de raios-X de cortes transversais de retinas de humanos vivos agrupadas em quatro categorias, NORMAL, CNV, DME e DRUSEN:
O conjunto de dados tem um total de 84.495 imagens de raios-X em JPEG, normalmente com 512 x 496 pixels, e pode ser baixado pelo CLI do kaggle:
#!pip install kaggle#!kaggle datasets download -d paultimothymooney/kermany2018
Depois de baixar o conjunto de treinamento e teste, as classes de imagem estarão em suas respectivas pastas. Definiremos um padrão da seguinte forma:
labels = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
train_folder = os.path.join('OCT2017', 'train', '**', '*.jpeg')test_folder = os.path.join('OCT2017', 'test', '**', '*.jpeg')
Depois, criamos a função "input" do nosso estimador, que recebe qualquer padrão de arquivos e retorna imagens redimensionadas e rótulos codificados do tipo one-hot como um tf.data.Dataset. Desta vez, seguimos as práticas recomendadas do Guia de alto desempenho dos canais de entrada. Veja que se o "buffer_size" do prefetch for "None", o TensorFlow usará automaticamente um tamanho de buffer de pré-busca ideal:
Desta vez, usaremos um VGG16 pré-treinado e retreinaremos somente suas últimas 5 camadas para treinar este modelo:
keras_vgg16 = tf.keras.applications.VGG16(input_shape=(224,224,3), include_top=False)
output = keras_vgg16.outputoutput = tf.keras.layers.Flatten()(output)prediction = tf.keras.layers.Dense(len(labels), activation=tf.nn.softmax)(output)
model = tf.keras.Model(inputs=keras_vgg16.input, outputs=prediction)
for layer in keras_vgg16.layers[:-4]: layer.trainable = False
Agora, temos tudo o que precisamos e podemos prosseguir, como mencionado acima, e treinar o nosso modelo em alguns minutos usando NUM_GPUS GPUs:
model.compile(loss='categorical_crossentropy', optimizer=tf.train.AdamOptimizer(), metrics=['accuracy'])
NUM_GPUS = 2strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)config = tf.estimator.RunConfig(train_distribute=strategy)estimator = tf.keras.estimator.model_to_estimator(model, config=config)
BATCH_SIZE = 64EPOCHS = 1
estimator.train(input_fn=lambda:input_fn(train_folder, labels, shuffle=True, batch_size=BATCH_SIZE, buffer_size=2048, num_epochs=EPOCHS, prefetch_buffer_size=4), hooks=[time_hist])
Depois do treinamento, podemos avaliar a precisão no conjunto de teste, que deve ser de cerca de 95% (nada mau para um ponto de referência inicial 😀):
estimator.evaluate(input_fn=lambda:input_fn(test_folder, labels, shuffle=False, batch_size=BATCH_SIZE, buffer_size=1024, num_epochs=1))
Mostramos acima como é fácil treinar modelos Keras de aprendizado profundo com várias GPUs usando a API Estimators, como criar um canal de entrada que segue as práticas recomendadas para ter boa utilização de recursos (escalonamento linear) e como marcar o tempo da produtividade do treinamento com hooks.
É bom destacar que, no final, o ponto mais importante são os erros do conjunto de teste. Você pode notar que a precisão do conjunto de teste cai quando aumentamos o NUM_GPUS. Um dos motivos pode ser o fato de que MirroredStrategy realmente treina com um tamanho de lote de BATCH_SIZE*NUM_GPUS, o que pode exigir o ajuste de BATCH_SIZE ou da taxa de aprendizado quando se usa mais GPUs. Neste caso, separei todos os outros hiperparâmetros da constante NUM_GPUS para facilitar os traçados. No entanto, em um cenário real, será preciso ajustá-los.
O tamanho do conjunto de dados, além do tamanho do modelo, afeta a capacidade de dimensionamento desses esquemas. As GPUs têm baixa largura de banda na leitura e gravação de dados de baixo volume, especialmente as GPUs mais antigas, como a K80, e podem ser responsáveis pelos traçados de Fashion-MNIST acima.
Quero agradecer à equipe do TensorFlow, especialmente a Josh Gordon, e a todos da Zalando Research pela ajuda na correção do rascunho, especialmente Duncan Blythe, Gokhan Yildirim e Sebastian Heinz.
// Send a binary message from Dart to the platform.
final WriteBuffer buffer = WriteBuffer() ..putFloat64(3.1415) ..putInt32(12345678); final ByteData message = buffer.done(); await BinaryMessages.send('foo', message); print('Message sent, reply ignored');
// Receive binary messages from Dart on Android. // This code can be added to a FlutterActivity subclass, typically // in onCreate.
flutterView.setMessageHandler("foo") { message, reply -> message.order(ByteOrder.nativeOrder()) val x = message.double val n = message.int Log.i("MSG", "Received: $x and $n") reply.reply(null) }
// Receive binary messages from Dart on iOS. // This code can be added to a FlutterAppDelegate subclass, // typically in application:didFinishLaunchingWithOptions:.
let flutterView = window?.rootViewController as! FlutterViewController; flutterView.setMessageHandlerOnChannel("foo") { (message: Data!, reply: FlutterBinaryReply) -> Void in let x : Float64 = message.subdata(in: 0..<8) .withUnsafeBytes { $0.pointee } let n : Int32 = message.subdata(in: 8..<12) .withUnsafeBytes { $0.pointee } os_log("Received %f and %d", x, n) reply(nil) }
// Send a binary message from Android.
val message = ByteBuffer.allocateDirect(12) message.putDouble(3.1415) message.putInt(123456789) flutterView.send("foo", message) { _ -> Log.i("MSG", "Message sent, reply ignored") }
// Send a binary message from iOS.
var message = Data(capacity: 12) var x : Float64 = 3.1415 var n : Int32 = 12345678 message.append(UnsafeBufferPointer(start: &x, count: 1)) message.append(UnsafeBufferPointer(start: &n, count: 1)) flutterView.send(onChannel: "foo", message: message) {(_) -> Void in os_log("Message sent, reply ignored") }
// Receive binary messages from the platform.
BinaryMessages.setMessageHandler('foo', (ByteData message) async { final ReadBuffer readBuffer = ReadBuffer(message); final double x = readBuffer.getFloat64(); final int n = readBuffer.getInt32(); print('Received $x and $n'); return null; });
// String messages
// Dart side
const channel = BasicMessageChannel<String>('foo', StringCodec());
// Send message to platform and receive reply.final String reply = await channel.send('Hello, world'); print(reply);
// Receive messages from platform and send replies. channel.setMessageHandler((String message) async { print('Received: $message'); return 'Hi from Dart'; });
// Android side
val channel = BasicMessageChannel<String>( flutterView, "foo", StringCodec.INSTANCE)
// Send message to Dart and receive reply.channel.send("Hello, world") { reply -> Log.i("MSG", reply) }
// Receive messages from Dart and send replies.channel.setMessageHandler { message, reply -> Log.i("MSG", "Received: $message") reply.reply("Hi from Android") }
// iOS side
let channel = FlutterBasicMessageChannel( name: "foo", binaryMessenger: controller, codec: FlutterStringCodec.sharedInstance())
// Send message to Dart and receive reply.channel.sendMessage("Hello, world") {(reply: Any?) -> Void in os_log("%@", type: .info, reply as! String) }
// Receive messages from Dart and send replies. channel.setMessageHandler { (message: Any?, reply: FlutterReply) -> Void in os_log("Received: %@", type: .info, message as! String) reply("Hi from iOS") }
const codec = StringCodec();
// Send message to platform and receive reply.final String reply = codec.decodeMessage( await BinaryMessages.send( 'foo', codec.encodeMessage('Hello, world'), ), ); print(reply);
// Receive messages from platform and send replies. BinaryMessages.setMessageHandler('foo', (ByteData message) async { print('Received: ${codec.decodeMessage(message)}'); return codec.encodeMessage('Hi from Dart'); });
final String reply1 = await channel.send(msg1); final int reply2 = await channel.send(msg2);
final List<String> reply3 = await channel.send(msg3); // Fails.final List<dynamic> reply3 = await channel.send(msg3); // Works.
Future<String> greet() => channel.send('hello, world'); // Fails.Future<String> greet() async { // Works. final String reply = await channel.send('hello, world'); return reply; }
// Invocation of platform methods, simple case.
// Dart side.
const channel = MethodChannel('foo'); final String greeting = await channel.invokeMethod('bar', 'world'); print(greeting);
// Android side.
val channel = MethodChannel(flutterView, "foo") channel.setMethodCallHandler { call, result -> when (call.method) { "bar" -> result.success("Hello, ${call.arguments}") else -> result.notImplemented() } }
// iOS side.
let channel = FlutterMethodChannel( name: "foo", binaryMessenger: flutterView) channel.setMethodCallHandler { (call: FlutterMethodCall, result: FlutterResult) -> Void in switch (call.method) { case "bar": result("Hello, \(call.arguments as! String)") default: result(FlutterMethodNotImplemented) } }
const codec = StandardMethodCodec();
final ByteData reply = await BinaryMessages.send( 'foo', codec.encodeMethodCall(MethodCall('bar', 'world')), ); if (reply == null) throw MissingPluginException(); else print(codec.decodeEnvelope(reply));
await channel.invokeMethod('bar'); await channel.invokeMethod('bar', <dynamic>['world', 42, pi]); await channel.invokeMethod('bar', <String, dynamic>{ name: 'world', answer: 42, math: pi, }));
// Method calls with error handling.
const channel = MethodChannel('foo');
// Invoke a platform method.const name = 'bar'; // or 'baz', or 'unknown' const value = 'world'; try { print(await channel.invokeMethod(name, value)); } on PlatformException catch(e) { print('$name failed: ${e.message}'); } on MissingPluginException { print('$name not implemented'); }
// Receive method invocations from platform and return results. channel.setMethodCallHandler((MethodCall call) async { switch (call.method) { case 'bar': return 'Hello, ${call.arguments}'; case 'baz': throw PlatformException(code: '400', message: 'This is bad'); default: throw MissingPluginException(); } });
val channel = MethodChannel(flutterView, "foo")
// Invoke a Dart method. val name = "bar" // or "baz", or "unknown" val value = "world" channel.invokeMethod(name, value, object: MethodChannel.Result { override fun success(result: Any?) { Log.i("MSG", "$result") } override fun error(code: String?, msg: String?, details: Any?) { Log.e("MSG", "$name failed: $msg") } override fun notImplemented() { Log.e("MSG", "$name not implemented") } })
// Receive method invocations from Dart and return results. channel.setMethodCallHandler { call, result -> when (call.method) { "bar" -> result.success("Hello, ${call.arguments}") "baz" -> result.error("400", "This is bad", null) else -> result.notImplemented() } }
let channel = FlutterMethodChannel( name: "foo", binaryMessenger: flutterView)
// Invoke a Dart method.let name = "bar" // or "baz", or "unknown" let value = "world" channel.invokeMethod(name, arguments: value) { (result: Any?) -> Void in if let error = result as? FlutterError { os_log("%@ failed: %@", type: .error, name, error.message!) } else if FlutterMethodNotImplemented.isEqual(result) { os_log("%@ not implemented", type: .error, name) } else { os_log("%@", type: .info, result as! NSObject) } }
// Receive method invocations from Dart and return results.channel.setMethodCallHandler { (call: FlutterMethodCall, result: FlutterResult) -> Void in switch (call.method) { case "bar": result("Hello, \(call.arguments as! String)") case "baz": result(FlutterError( code: "400", message: "This is bad", details: nil)) default: result(FlutterMethodNotImplemented) }
// Consuming events on the Dart side.
const channel = EventChannel('foo');
channel.receiveBroadcastStream().listen((dynamic event) { print('Received event: $event'); }, onError: (dynamic error) { print('Received error: ${error.message}'); });
// Producing sensor events on Android.
// SensorEventListener/EventChannel adapter.class SensorListener(private val sensorManager: SensorManager) : EventChannel.StreamHandler, SensorEventListener { private var eventSink: EventChannel.EventSink? = null // EventChannel.StreamHandler methods override fun onListen( arguments: Any?, eventSink: EventChannel.EventSink?) { this.eventSink = eventSink registerIfActive() } override fun onCancel(arguments: Any?) { unregisterIfActive() eventSink = null } // SensorEventListener methods. override fun onSensorChanged(event: SensorEvent) { eventSink?.success(event.values) } override fun onAccuracyChanged(sensor: Sensor?, accuracy: Int) { if (accuracy == SensorManager.SENSOR_STATUS_ACCURACY_LOW) eventSink?.error("SENSOR", "Low accuracy detected", null) }
// Lifecycle methods. fun registerIfActive() { if (eventSink == null) return sensorManager.registerListener( this, sensorManager.getDefaultSensor(Sensor.TYPE_GYROSCOPE), SensorManager.SENSOR_DELAY_NORMAL) } fun unregisterIfActive() { if (eventSink == null) return sensorManager.unregisterListener(this) } }
// Use of the above class in an Activity.class MainActivity: FlutterActivity() { var sensorListener: SensorListener? = null override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) GeneratedPluginRegistrant.registerWith(this) sensorListener = SensorListener( getSystemService(Context.SENSOR_SERVICE) as SensorManager) val channel = EventChannel(flutterView, "foo") channel.setStreamHandler(sensorListener) } override fun onPause() { sensorListener?.unregisterIfActive() super.onPause() } override fun onResume() { sensorListener?.registerIfActive() super.onResume() } }
// Dart: we expect to receive a non-null List of integers. for (final int n in await channel.invokeMethod('getFib', 100)) { print(n * n); }
// Android: we expect non-null name and age arguments for // asynchronous processing, delivered in a string-keyed map. channel.setMethodCallHandler { call, result -> when (call.method) { "bar" -> { val name : String = call.argument("name") val age : Int = call.argument("age") process(name, age, result) } else -> result.notImplemented() } } : fun process(name: String, age: Int, result: Result) { ... }
test('gets greeting from platform', () async { const channel = MethodChannel('foo'); channel.setMockMethodCallHandler((MethodCall call) async { if (call.method == 'bar') return 'Hello, ${call.arguments}'; throw MissingPluginException(); }); expect(await hello('world'), 'Platform says: Hello, world'); });
test('collects incoming arguments', () async { const channel = MethodChannel('foo'); final hello = Hello(); final String result = await handleMockCall( channel, MethodCall('bar', 'world'), ); expect(result, contains('Hello, world')); expect(hello.collectedArguments, contains('world')); });
// Could be made an instance method on class MethodChannel. Future<dynamic> handleMockCall( MethodChannel channel, MethodCall call, ) async { dynamic result; await BinaryMessages.handlePlatformMessage( channel.name, channel.codec.encodeMethodCall(call), (ByteData reply) { if (reply == null) throw MissingPluginException(); result = channel.codec.decodeEnvelope(reply); }, ); return result; }