No public description

PiperOrigin-RevId: 611791844
Sebastian Schmidt 2024-03-01 06:51:33 -08:00 committed by Copybara-Service
parent faacda8134
commit 81b81b045b
20 changed files with 1928 additions and 885 deletions

View File

@@ -52,6 +52,7 @@ def mediapipe_ts_library(
"@npm//@types/node",
"@npm//@types/offscreencanvas",
"@npm//@types/google-protobuf",
"@npm//@webgpu/types",
],
testonly = testonly,
declaration = True,

View File

@@ -0,0 +1,40 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:__subpackages__",
])
licenses(["notice"])
mediapipe_proto_library(
name = "detokenizer_calculator_proto",
srcs = ["detokenizer_calculator.proto"],
)
mediapipe_proto_library(
name = "tokenizer_calculator_proto",
srcs = ["tokenizer_calculator.proto"],
)
mediapipe_proto_library(
name = "llm_gpu_calculator_proto",
srcs = ["llm_gpu_calculator.proto"],
deps = [
"//mediapipe/tasks/cc/genai/inference/proto:llm_file_metadata_proto",
"//mediapipe/tasks/cc/genai/inference/proto:llm_params_proto",
],
)

View File

@@ -0,0 +1,34 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package odml.infra.proto;
option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "DetokenizerCalculatorOptionsProto";
message DetokenizerCalculatorOptions {
// The path to the SentencePiece model file.
string spm_model_file = 1;
// A set of tokens to stop the decoding process whenever they appear in the
// result.
repeated string stop_tokens = 4;
// How many sets of input IDs need to be detokenized.
int32 num_output_heads = 5;
reserved 2, 3;
}
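For reference, the Web task added later in this commit configures this calculator through its generated jspb bindings. A minimal sketch (import path as used from the Web task's source directory; field values illustrative):
```
import {DetokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/detokenizer_calculator_pb';

const options = new DetokenizerCalculatorOptions();
// Stop decoding as soon as the end-of-sequence token shows up in the output.
options.addStopTokens('<eos>');
// Detokenize a single set of input IDs per step.
options.setNumOutputHeads(1);
// spm_model_file can stay unset when the TokenizerCalculator's side output is
// reused, as the Web graph below does.
```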

View File

@@ -0,0 +1,72 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package odml.infra.proto;
import "mediapipe/tasks/cc/genai/inference/proto/llm_file_metadata.proto";
import "mediapipe/tasks/cc/genai/inference/proto/llm_params.proto";
option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "LlmGpuCalculatorOptionsProto";
message LlmGpuCalculatorOptions {
string weight_path = 1;
// Use LlmParameters instead.
reserved 2, 3, 4, 5, 6, 7, 8, 9;
message GpuModelInfo {
// If set to true, use float16 precision in computation.
bool allow_precision_loss = 1;
bool enable_fast_tuning = 2;
bool enable_winograd_opt = 3;
bool use_low_power = 4;
bool prefer_texture_weights = 5;
bool enable_host_mapped_pointer = 6;
}
GpuModelInfo gpu_model_info = 10;
// Use LlmParameters instead.
reserved 11;
int32 num_decode_tokens = 12;
// Use lora_ranks instead.
reserved 13;
int32 sequence_batch_size = 14;
string lora_path = 19;
odml.infra.proto.LlmParameters llm_parameters = 20;
// Each output head will generate tokens independently of the others.
int32 num_output_heads = 22;
// The number of candidate tokens to sample from our softmax output in top-k
// sampling.
int32 topk = 23;
// The softmax temperature. For any value less than 1/1024 (the difference
// between 1.0 and the next representable value for half-precision floats),
// the sampling op collapses to an ArgMax.
float temperature = 24;
// Random seed for sampling tokens.
optional uint32 random_seed = 25;
reserved 26, 27;
}
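The sampling-related fields above map directly onto the generated jspb setters used by the Web task below. A hedged sketch with illustrative values (the Web graph in this commit itself uses topk = 1 and temperature = 1.0):
```
import {LlmGpuCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/llm_gpu_calculator_pb';

const options = new LlmGpuCalculatorOptions();
options.setWeightPath('llm.tflite');  // model weights, as bundled by the Web task
options.setNumDecodeTokens(3);        // tokens decoded per invocation
options.setNumOutputHeads(1);         // a single, independent output head
options.setTopk(40);                  // sample from the 40 most likely tokens (illustrative)
options.setTemperature(0.8);          // above 1/1024, so this samples rather than ArgMax
options.setRandomSeed(1);             // fixed seed for reproducible sampling (illustrative)

const gpuInfo = new LlmGpuCalculatorOptions.GpuModelInfo();
gpuInfo.setAllowPrecisionLoss(true);  // compute in float16
options.setGpuModelInfo(gpuInfo);
```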

View File

@@ -0,0 +1,53 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package odml.infra.proto;
option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "TokenizerCalculatorOptionsProto";
message TokenizerCalculatorOptions {
// The maximum number of tokens for the calculator's BERT model. Used
// if the model's input tensors have static shape.
int32 max_tokens = 1;
message TfLiteModelFile {
// (Optional) The path to the tflite model file, whose metadata (looked up by
// the key below) contains the buffer used to initialize the tokenizer. If
// empty, the model should be provided through e.g. a side input.
string model_file = 1;
// The key in the tflite model metadata field whose corresponding data is
// a SentencePiece model. If not provided, "spm_vocab_model" is used.
string spm_model_key_in_metadata = 2;
}
oneof model_file {
// The path to the SentencePiece model file.
string spm_model_file = 2;
// Options to load the model from the tflite model's metadata.
TfLiteModelFile tflite_model_file = 4;
}
// The start token id to be prepended to the input prompt token ids. Note that
// this should match the setting used while training the model.
int32 start_token_id = 3;
// Whether to convert string bytes to Unicode before tokenization. This
// mapping is used in GPT-2-style tokenizers.
bool bytes_to_unicode_mapping = 5;
}
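The `model_file` oneof means a caller picks exactly one source for the SentencePiece model. A hedged sketch of both branches via the generated jspb bindings (the path in the first branch is purely illustrative):
```
import {TokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/tokenizer_calculator_pb';

const options = new TokenizerCalculatorOptions();
options.setMaxTokens(512);
options.setStartTokenId(2);

// Branch 1: a standalone SentencePiece model file.
options.setSpmModelFile('/path/to/spm.model');  // illustrative path

// Branch 2 (used by the Web task below): read the SentencePiece model from the
// tflite model's metadata. Setting this clears spm_model_file, since both
// fields live in the same oneof.
const tfliteModelFile = new TokenizerCalculatorOptions.TfLiteModelFile();
tfliteModelFile.setSpmModelKeyInMetadata('spm_vocab_model');
options.setTfliteModelFile(tfliteModelFile);
```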

View File

@@ -31,3 +31,9 @@ mediapipe_proto_library(
srcs = ["llm_params.proto"],
deps = [":transformer_params_proto"],
)
mediapipe_proto_library(
name = "llm_file_metadata_proto",
srcs = ["llm_file_metadata.proto"],
deps = [":llm_params_proto"],
)

View File

@@ -0,0 +1,54 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package odml.infra.proto;
import "mediapipe/tasks/cc/genai/inference/proto/llm_params.proto";
option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "LlmFileMetadataProto";
// Metadata for an LLM file loaded by `GpuOptimizedTensorLoader`.
message LlmFileMetadata {
message TensorInfo {
// The name of the tensor, e.g. "params.lm.softmax.logits_ffn.linear.w".
string name = 1;
// The offset of the tensor data in the file in bytes.
uint64 offset = 2;
// The size of the tensor data in bytes.
uint64 size = 3;
// In some cases, `size` alone is insufficient to determine the data type,
// e.g. 4-bit data could be INT4 or UINT4; `data_type` provides more
// precise information.
enum DataType {
UNSPECIFIED = 0;
FLOAT32 = 1;
INT8 = 2;
INT4 = 3;
UINT4 = 4;
}
DataType data_type = 4;
}
repeated TensorInfo tensors = 1;
odml.infra.proto.LlmParameters model_params = 2;
// The LoRA rank if this is a set of LoRA weights.
int32 lora_rank = 3;
}
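A hedged sketch of how a tensor entry might be described with these messages, assuming the generated jspb module follows the same naming pattern as the other protos in this directory (offset and size values are illustrative):
```
import {LlmFileMetadata} from '../../../../tasks/cc/genai/inference/proto/llm_file_metadata_pb';

const tensorInfo = new LlmFileMetadata.TensorInfo();
tensorInfo.setName('params.lm.softmax.logits_ffn.linear.w');
tensorInfo.setOffset(0);     // byte offset of the tensor data in the file
tensorInfo.setSize(4096);    // size in bytes (illustrative)
// 4-bit data is ambiguous by size alone, so spell out the type explicitly.
tensorInfo.setDataType(LlmFileMetadata.TensorInfo.DataType.INT4);

const metadata = new LlmFileMetadata();
metadata.addTensors(tensorInfo);
metadata.setLoraRank(0);     // 0: these are not LoRA weights
```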

View File

@@ -0,0 +1,119 @@
# This contains the MediaPipe GenAI Tasks.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_files(srcs = [
"wasm/genai_wasm_internal.js",
"wasm/genai_wasm_internal.wasm",
"wasm/genai_wasm_nosimd_internal.js",
"wasm/genai_wasm_nosimd_internal.wasm",
])
GENAI_LIBS = [
"//mediapipe/tasks/web/core:fileset_resolver",
"//mediapipe/tasks/web/genai/llm_inference",
]
mediapipe_ts_library(
name = "genai_lib",
srcs = ["index.ts"],
visibility = ["//visibility:public"],
deps = GENAI_LIBS,
)
mediapipe_ts_library(
name = "genai_types",
srcs = ["types.ts"],
visibility = ["//visibility:public"],
deps = GENAI_LIBS,
)
rollup_bundle(
name = "genai_bundle_mjs",
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
entry_point = "index.ts",
format = "esm",
output_dir = False,
sourcemap = "true",
deps = [
":genai_lib",
"@npm//@rollup/plugin-commonjs",
"@npm//@rollup/plugin-node-resolve",
"@npm//@rollup/plugin-terser",
"@npm//google-protobuf",
],
)
rollup_bundle(
name = "genai_bundle_cjs",
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
entry_point = "index.ts",
format = "cjs",
output_dir = False,
sourcemap = "true",
deps = [
":genai_lib",
"@npm//@rollup/plugin-commonjs",
"@npm//@rollup/plugin-node-resolve",
"@npm//@rollup/plugin-terser",
"@npm//google-protobuf",
],
)
genrule(
name = "genai_sources",
srcs = [
":genai_bundle_cjs",
":genai_bundle_mjs",
],
outs = [
"genai_bundle.cjs",
"genai_bundle.cjs.map",
"genai_bundle.mjs",
"genai_bundle.mjs.map",
],
cmd = (
"for FILE in $(SRCS); do " +
" OUT_FILE=$(GENDIR)/mediapipe/tasks/web/genai/$$(" +
" basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
" ); " +
" echo $$FILE ; echo $$OUT_FILE ; " +
" cp $$FILE $$OUT_FILE ; " +
"done;"
),
)
genrule(
name = "package_json",
srcs = ["//mediapipe/tasks/web:package.json"],
outs = ["package.json"],
cmd = "cp $< $@",
)
pkg_npm(
name = "genai_pkg",
package_name = "@mediapipe/tasks-__NAME__",
srcs = ["README.md"],
substitutions = {
"__NAME__": "genai",
"__DESCRIPTION__": "MediaPipe GenAI Tasks",
"__TYPES__": "genai.d.ts",
},
tgz = "genai.tgz",
deps = [
"wasm/genai_wasm_internal.js",
"wasm/genai_wasm_internal.wasm",
"wasm/genai_wasm_nosimd_internal.js",
"wasm/genai_wasm_nosimd_internal.wasm",
":genai_sources",
":package_json",
],
)

View File

@@ -0,0 +1,17 @@
# MediaPipe Tasks GenAI Package
This package contains the GenAI tasks for MediaPipe.
## LLM Inference
The MediaPipe LLM Inference task generates a text response from input text.
```
const genai = await FilesetResolver.forGenAiTasks(
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm"
);
const llmInference = await LlmInference.createFromModelPath(genai, MODEL_URL);
const response = llmInference.generateResponse(inputText);
```
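Based on the task class added in this commit, partial results can also be streamed as they are generated; a sketch, reusing the `genai` fileset and `llmInference` instance from above (in this commit, `generateResponseAsync` resolves with an array of decoded chunks):
```
llmInference.setChunkGenerationCallback((chunk) => {
  // Invoked with each newly decoded piece of text.
  console.log(chunk);
});
const responses = await llmInference.generateResponseAsync(inputText);
console.log(responses.join(''));
```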
<!-- TODO: Complete README for MediaPipe GenAI Task. -->

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2024 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {LlmInference as LlmInferenceImpl} from '../../../tasks/web/genai/llm_inference/llm_inference';
// Declare the variables locally so that Rollup in OSS includes them explicitly
// as exports.
const FilesetResolver = FilesetResolverImpl;
const LlmInference = LlmInferenceImpl;
export {FilesetResolver, LlmInference};
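Once published via the `pkg_npm` rule above (where `__NAME__` resolves to `genai`), consumers would pull these re-exports from the npm package; a short sketch:
```
import {FilesetResolver, LlmInference} from '@mediapipe/tasks-genai';

const genai = await FilesetResolver.forGenAiTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm');
```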

View File

@@ -0,0 +1,43 @@
# This contains the MediaPipe LLM Inference Task.
#
# This task takes text input and performs text generation.
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_ts_library(
name = "llm_inference",
srcs = ["llm_inference.ts"],
visibility = ["//visibility:public"],
deps = [
":llm_inference_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/genai/inference/calculators:detokenizer_calculator_jspb_proto",
"//mediapipe/tasks/cc/genai/inference/calculators:llm_gpu_calculator_jspb_proto",
"//mediapipe/tasks/cc/genai/inference/calculators:tokenizer_calculator_jspb_proto",
"//mediapipe/tasks/cc/genai/inference/proto:llm_params_jspb_proto",
"//mediapipe/tasks/cc/genai/inference/proto:transformer_params_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/tasks/web/genai/llm_inference/proto:llm_inference_graph_options_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner/internal:graph_runner_wasm_file_reference_ts",
"//mediapipe/web/graph_runner/internal:graph_runner_webgpu_ts",
],
)
mediapipe_ts_declaration(
name = "llm_inference_types",
srcs = [
"llm_inference_options.d.ts",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/web/core",
],
)

View File

@@ -0,0 +1,435 @@
/**
* Copyright 2024 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Any} from 'google-protobuf/google/protobuf/any_pb';
import {CalculatorGraphConfig, InputStreamInfo,} from '../../../../framework/calculator_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {CachedGraphRunner, TaskRunner,} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {LlmInferenceGraphOptions as LlmInferenceGraphOptionsProto} from '../../../../tasks/web/genai/llm_inference/proto/llm_inference_graph_options_pb';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
import {SupportWasmFileReference, WasmFileReference} from '../../../../web/graph_runner/internal/graph_runner_wasm_file_reference';
import {SupportWebGpu} from '../../../../web/graph_runner/internal/graph_runner_webgpu';
import {DetokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/detokenizer_calculator_pb';
import {LlmGpuCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/llm_gpu_calculator_pb';
import {TokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/tokenizer_calculator_pb';
import {LlmParameters} from '../../../../tasks/cc/genai/inference/proto/llm_params_pb';
import {TransformerParameters} from '../../../../tasks/cc/genai/inference/proto/transformer_params_pb';
import {LlmInferenceOptions} from './llm_inference_options';
export * from './llm_inference_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
// TODO: b/327515383 - Use ReturnType pattern to apply extensions to LLM Web API.
// tslint:disable-next-line:enforce-name-casing
const WasmFileReferenceWebGpuGraphRunnerType =
SupportWebGpu(SupportWasmFileReference(CachedGraphRunner));
class WasmFileReferenceWebGpuGraphRunner extends
WasmFileReferenceWebGpuGraphRunnerType {}
/**
* A callback that receives the result from the LLM Inference.
*/
export type LlmInferenceCallback = (result: string[]) => void;
const INPUT_STREAM = 'text_in';
const OUTPUT_STREAM = 'text_out';
const OUTPUT_END_STREAM = 'text_end';
/**
* Performs LLM Inference on text.
*/
export class LlmInference extends TaskRunner {
private static readonly TOKEN_SPLITTER =
'▁'; // Note this is NOT an underscore: ▁(U+2581)
private static readonly NEW_LINE = '<0x0A>';
private static readonly EOD = '\\[eod\\]';
private static readonly LLM_MODEL_NAME = 'llm.tflite';
private static readonly TOKENIZER_MODE_IN_TFLITE_KEY = 'spm_vocab_model';
private readonly generationResult: string[] = [];
private readonly options = new LlmInferenceGraphOptionsProto();
private isProcessing = false;
private resolveGeneration?: (result: string[]) => void;
private userCallback: LlmInferenceCallback = (result: string[]) => {};
private chunkGenerationCallback = (result: string) => {};
private wasmFileReference?: WasmFileReference;
/**
* Initializes the Wasm runtime and creates a new LLM Inference task from
* the provided options.
* @export
* @param wasmFileset A configuration object that provides the location of the
* Wasm binary and its loader.
* @param llmInferenceOptions The options for the LLM Inference. Note that
* either a path to the TFLite model or the model itself needs to be
* provided (via `baseOptions`).
*/
static async createFromOptions(
wasmFileset: WasmFileset,
llmInferenceOptions: LlmInferenceOptions): Promise<LlmInference> {
// TODO: b/324482487 - Support customizing config for Web task of LLM
// Inference.
const optionsWithGpuDevice = llmInferenceOptions;
if (!optionsWithGpuDevice.baseOptions?.gpuOptions?.device) {
const deviceDescriptor: GPUDeviceDescriptor = {
requiredFeatures: ['shader-f16'],
requiredLimits: {
'maxStorageBufferBindingSize': 524550144,
'maxBufferSize': 524550144,
},
};
const webgpuDevice =
await WasmFileReferenceWebGpuGraphRunner.requestWebGpuDevice(
deviceDescriptor);
optionsWithGpuDevice.baseOptions = llmInferenceOptions.baseOptions ?? {};
optionsWithGpuDevice.baseOptions.gpuOptions =
llmInferenceOptions?.baseOptions?.gpuOptions ?? {};
optionsWithGpuDevice.baseOptions.gpuOptions.device = webgpuDevice;
}
return TaskRunner.createInstance(
LlmInference, /* canvas= */ null, wasmFileset, optionsWithGpuDevice);
}
/** @hideconstructor */
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(new WasmFileReferenceWebGpuGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}
// TODO: b/325936012 - Move setChunkGeneration to LLM Inference Task option.
/**
* When the LLM Inference task has generated new tokens, the callback is
* invoked with a string containing these new tokens.
*
* @param callback The callback that is invoked with the newly generated
* tokens.
*/
setChunkGenerationCallback(callback: (result: string) => void) {
this.chunkGenerationCallback = callback;
}
/**
* Sets new options for the LLM Inference task.
*
* Calling `setOptions()` with a subset of options only affects those options.
* You can reset an option back to its default value by explicitly setting it
* to `undefined`.
*
* @export
* @param options The options for the LLM Inference task.
*/
override setOptions(options: LlmInferenceOptions): Promise<void> {
// TODO: b/324482487 - Support customizing config for Web task of LLM
// Inference.
if (this.wasmFileReference) {
this.wasmFileReference.free();
}
if (options.baseOptions?.gpuOptions?.device) {
(this.graphRunner as unknown as WasmFileReferenceWebGpuGraphRunner)
.initializeForWebGpu(options.baseOptions.gpuOptions.device);
}
if (options?.baseOptions?.modelAssetPath) {
return WasmFileReference
.loadFromUrl(
this.graphRunner.wasmModule, options.baseOptions.modelAssetPath)
.then((wasmFileReference: WasmFileReference) => {
this.wasmFileReference = wasmFileReference;
this.refreshGraph();
this.onGraphRefreshed();
});
} else if (options?.baseOptions?.modelAssetBuffer) {
this.wasmFileReference = WasmFileReference.loadFromArray(
this.graphRunner.wasmModule, options.baseOptions.modelAssetBuffer);
this.refreshGraph();
this.onGraphRefreshed();
}
return Promise.resolve();
}
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto);
}
/**
* Decodes the response from the LLM engine and returns a human-readable
* string.
*/
static decodeResponse(responses: string[], stripLeadingWhitespace: boolean):
string {
if (responses == null || responses.length === 0) {
// Technically, this is an error. We should always get at least one
// response.
return '';
}
let response = responses[0]; // We only use the first response
response = response.replaceAll(LlmInference.TOKEN_SPLITTER, ' ');
response = response.replaceAll(
LlmInference.NEW_LINE, '\n'); // Replace <0x0A> token with newline
if (stripLeadingWhitespace) {
response = response.trimStart();
}
return response.split(LlmInference.EOD, 1)[0];
}
/**
* Performs LLM inference on the provided text and invokes the callback
* with the generated response once processing is complete.
*
* @export
* @param text The text to process.
* @param callback The callback that is invoked with the generated text
*     results.
*/
generateResponse(text: string, callback: LlmInferenceCallback): void {
if (this.isProcessing) {
throw new Error('Previous invocation is still processing.');
}
this.generationResult.length = 0;
this.userCallback = callback;
this.isProcessing = true;
this.graphRunner.addStringToStream(
text, INPUT_STREAM, this.getSynctheticTimestamp());
this.finishProcessing();
}
/**
* Performs LLM inference on the provided text and asynchronously returns
* the generated response.
*
* @export
* @param text The text to process.
* @return A Promise that resolves with the generated text results.
*/
generateResponseAsync(text: string): Promise<string[]> {
if (this.isProcessing) {
throw new Error('Previous invocation is still processing.');
}
this.generationResult.length = 0;
this.isProcessing = true;
this.graphRunner.addStringToStream(
text, INPUT_STREAM, this.getSynctheticTimestamp());
this.finishProcessing();
return new Promise<string[]>((resolve, reject) => {
this.resolveGeneration = resolve;
});
}
// TODO: b/324919242 - Add sync API for BYOM Web API when Chrome JSPI is
// available
/** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void {
const graphConfig = this.buildLlmInferenceGraph();
this.graphRunner.attachStringVectorListener(
OUTPUT_STREAM, (stringVector, timestamp) => {
const stripLeadingWhitespace = this.generationResult.length === 0;
const decodedText =
LlmInference.decodeResponse(stringVector, stripLeadingWhitespace);
this.generationResult.push(decodedText);
this.chunkGenerationCallback(decodedText);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(OUTPUT_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachBoolListener(
OUTPUT_END_STREAM, (bool, timestamp) => {
this.isProcessing = false;
if (this.resolveGeneration) {
this.resolveGeneration(this.generationResult);
}
this.userCallback(this.generationResult);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(OUTPUT_END_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
if (this.wasmFileReference) {
(this.graphRunner as unknown as WasmFileReferenceWebGpuGraphRunner)
.addWasmFileReferenceToInputSidePacket(
this.wasmFileReference,
'model_file_reference',
);
}
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}
private buildLlmInferenceGraph(): CalculatorGraphConfig {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addInputSidePacket('model_file_reference');
graphConfig.addOutputStream(OUTPUT_STREAM);
graphConfig.addOutputStream(OUTPUT_END_STREAM);
// TokenizerInputBuilder Node
const tokenizerInputBuildNode = new CalculatorGraphConfig.Node();
tokenizerInputBuildNode.setCalculator('TokenizerInputBuildCalculator');
tokenizerInputBuildNode.addInputStream(INPUT_STREAM);
tokenizerInputBuildNode.addOutputStream('prompt');
graphConfig.addNode(tokenizerInputBuildNode);
// TFLite model Node
const tfliteModelNode = new CalculatorGraphConfig.Node();
tfliteModelNode.setCalculator('TfLiteModelCalculator');
tfliteModelNode.addInputSidePacket(
'MODEL_SPAN:' +
'model_file_reference');
tfliteModelNode.addOutputSidePacket(
'SHARED_MODEL:' +
'__side_packet_0');
graphConfig.addNode(tfliteModelNode);
// Tokenizer Node
const tokenizerOptionsProto = new Any();
tokenizerOptionsProto.setTypeUrl(
'type.googleapis.com/odml.infra.proto.TokenizerCalculatorOptions');
const tokenizerOptions = new TokenizerCalculatorOptions();
tokenizerOptions.setMaxTokens(512);
const modelFile = new TokenizerCalculatorOptions.TfLiteModelFile();
modelFile.setSpmModelKeyInMetadata(
LlmInference.TOKENIZER_MODE_IN_TFLITE_KEY);
tokenizerOptions.setTfliteModelFile(modelFile);
tokenizerOptions.setStartTokenId(2);
tokenizerOptionsProto.setValue(tokenizerOptions.serializeBinary());
const tokenizerNode = new CalculatorGraphConfig.Node();
tokenizerNode.setCalculator('TokenizerCalculator');
tokenizerNode.addNodeOptions(tokenizerOptionsProto);
tokenizerNode.addInputStream(
'PROMPT:' +
'prompt');
tokenizerNode.addOutputSidePacket(
'PROCESSOR:' +
'__input_side_1');
tokenizerNode.addInputSidePacket(
'TFLITE_MODEL:' +
'__side_packet_0');
tokenizerNode.addOutputStream(
'IDS:' +
'__stream_0');
graphConfig.addNode(tokenizerNode);
// LlmGpu Node
const llmGpuOptionsProto = new Any();
llmGpuOptionsProto.setTypeUrl(
'type.googleapis.com/odml.infra.proto.LlmGpuCalculatorOptions');
const llmGpuOptions = new LlmGpuCalculatorOptions();
llmGpuOptions.setNumDecodeTokens(3);
llmGpuOptions.setWeightPath(LlmInference.LLM_MODEL_NAME);
llmGpuOptions.setSequenceBatchSize(0);
llmGpuOptions.setNumOutputHeads(1);
llmGpuOptions.setTopk(1);
llmGpuOptions.setTemperature(1.0);
const gpuModelInfo = new LlmGpuCalculatorOptions.GpuModelInfo();
gpuModelInfo.setAllowPrecisionLoss(true);
gpuModelInfo.setEnableFastTuning(true);
gpuModelInfo.setPreferTextureWeights(true);
llmGpuOptions.setGpuModelInfo(gpuModelInfo);
const llmParams = new LlmParameters();
const transformerParams = new TransformerParameters();
transformerParams.setBatchSize(1);
transformerParams.setMaxSeqLength(512);
llmParams.setTransformerParameters(transformerParams);
llmGpuOptions.setLlmParameters(llmParams);
llmGpuOptionsProto.setValue(llmGpuOptions.serializeBinary());
const llmGpuNode = new CalculatorGraphConfig.Node();
llmGpuNode.setCalculator('LlmGpuCalculator');
llmGpuNode.addNodeOptions(llmGpuOptionsProto);
llmGpuNode.addInputStream(
'INPUT_PROMPT_IDS:' +
'__stream_0');
llmGpuNode.addInputStream(
'FINISH:' +
'finish');
llmGpuNode.addInputSidePacket(
'SHARED_MODEL:' +
'__side_packet_0');
llmGpuNode.addOutputStream(
'DECODED_IDS:' +
'__stream_3');
llmGpuNode.addOutputStream(
'OUTPUT_END:' +
'__stream_4');
const backEdgeInputStreamInfo = new InputStreamInfo();
backEdgeInputStreamInfo.setTagIndex('FINISH');
backEdgeInputStreamInfo.setBackEdge(true);
llmGpuNode.addInputStreamInfo(backEdgeInputStreamInfo);
graphConfig.addNode(llmGpuNode);
const isPacketPresentNode = new CalculatorGraphConfig.Node();
isPacketPresentNode.setCalculator('IsPacketPresentCalculator');
isPacketPresentNode.addInputStream('__stream_4');
isPacketPresentNode.addOutputStream(OUTPUT_END_STREAM);
graphConfig.addNode(isPacketPresentNode);
// Detokenizer Node
const detokenizerOptionsProto = new Any();
detokenizerOptionsProto.setTypeUrl(
'type.googleapis.com/odml.infra.proto.DetokenizerCalculatorOptions');
const detokenizerOptions = new DetokenizerCalculatorOptions();
detokenizerOptions.setNumOutputHeads(1);
// No need to set the SPM model; reuse the TokenizerCalculator's side input instead.
detokenizerOptions.addStopTokens('<eos>');
detokenizerOptionsProto.setValue(detokenizerOptions.serializeBinary());
const detokenizerNode = new CalculatorGraphConfig.Node();
detokenizerNode.setCalculator('DetokenizerCalculator');
detokenizerNode.addNodeOptions(detokenizerOptionsProto);
detokenizerNode.addInputStream(
'IDS:' +
'__stream_3');
detokenizerNode.addInputSidePacket(
'PROCESSOR:' +
'__input_side_1');
detokenizerNode.addOutputStream('FINISH:finish');
detokenizerNode.addOutputStream('WORDS:' + OUTPUT_STREAM);
graphConfig.addNode(detokenizerNode);
return graphConfig;
}
override close() {
// TODO: b/327307061 - Release tflite file in Wasm heap at the earliest
// point
if (this.wasmFileReference) {
this.wasmFileReference.free();
}
super.close();
}
}
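The class exposes both generation entry points shown above; a hedged usage sketch contrasting them (`llmInference` and `prompt` are assumed to be in scope, and only one generation may run at a time since a second call while processing throws):
```
// Callback form: the callback fires once generation has finished.
llmInference.generateResponse(prompt, (responses) => {
  console.log(responses.join(''));
});

// Alternatively, the promise form: resolves with the same array of decoded chunks.
const responses = await llmInference.generateResponseAsync(prompt);
console.log(responses.join(''));
```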

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2024 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions, TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/**
* Options to configure the WebGPU device for the LLM Inference task.
*/
export declare interface WebGpuOptions {
device?: GPUDevice;
// TODO: b/327685206 - Fill adapter info for LLM Web task
adapterInfo?: GPUAdapterInfo;
}
/**
* Options to configure model loading and processing for the LLM Inference task.
*/
export declare interface LlmBaseOptions extends BaseOptions {
gpuOptions?: WebGpuOptions;
}
// TODO: b/324482487 - Support customizing config for Web task of LLM Inference.
/** Options to configure the MediaPipe LLM Inference Task */
export declare interface LlmInferenceOptions extends TaskRunnerOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: LlmBaseOptions;
}
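When `gpuOptions.device` is left unset, `createFromOptions` requests a default device with the `shader-f16` feature (see above). A hedged sketch of supplying your own WebGPU device instead (`genai` and `MODEL_URL` are assumed placeholders, as in the README):
```
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) throw new Error('WebGPU is not available');
const device = await adapter.requestDevice({
  requiredFeatures: ['shader-f16'],
});

const llmInference = await LlmInference.createFromOptions(genai, {
  baseOptions: {
    modelAssetPath: MODEL_URL,
    gpuOptions: {device},
  },
});
```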

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2024 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {LlmInference} from './llm_inference';
describe('LlmInference', () => {
it('decoding LLM Inference output succeeds', () => {
expect(LlmInference.decodeResponse([' strip leading'], true))
.toBe('strip leading');
expect(LlmInference.decodeResponse(['▁replace▁underscore▁'], false))
.toBe(' replace underscore ');
expect(LlmInference.decodeResponse(['replace_newline<0x0A>'], false))
.toBe(' replace newline\n');
});
});

View File

@@ -0,0 +1,29 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "llm_inference_graph_options_proto",
srcs = ["llm_inference_graph_options.proto"],
deps = [
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@@ -0,0 +1,32 @@
/* Copyright 2024 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.genai.llm_inference.proto;
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.genai.llminference.proto";
option java_outer_classname = "LlmInferenceGraphOptionsProto";
message LlmInferenceGraphOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// asset bundle file with metadata, accelerator options, etc.
core.proto.BaseOptions base_options = 1;
// TODO: b/324482487 - Support customizing config for Web task of LLM
// Inference.
}

View File

@@ -0,0 +1,18 @@
/**
* Copyright 2024 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/genai/llm_inference/llm_inference';

View File

@@ -13,6 +13,7 @@
"@types/jasmine": "^4.3.1",
"@types/node": "^18.11.11",
"@types/offscreencanvas": "^2019.7.0",
"@webgpu/types": "^0.1.40",
"google-protobuf": "^3.21.2",
"jasmine": "^4.5.0",
"jasmine-core": "^4.5.0",

View File

@@ -2,7 +2,7 @@
"compilerOptions": {
"target": "es2017",
"module": "commonjs",
"lib": ["ES2017", "dom"],
"lib": ["ES2021", "dom"],
"declaration": true,
"moduleResolution": "node",
"esModuleInterop": true,
@@ -10,7 +10,7 @@
"inlineSourceMap": true,
"inlineSources": true,
"strict": true,
"types": ["@types/offscreencanvas", "@types/jasmine", "node"],
"types": ["@types/offscreencanvas", "@types/jasmine", "@webgpu/types", "node"],
"rootDirs": [
".",
"./bazel-out/host/bin",

yarn.lock — 1760 changes (file diff suppressed because it is too large)