mirror of https://github.com/google-ai-edge/mediapipe.git
synced 2024-08-21 00:08:10 +08:00

commit 81b81b045b (parent faacda8134)

No public description

PiperOrigin-RevId: 611791844
mediapipe/framework/port/build_config.bzl

@@ -52,6 +52,7 @@ def mediapipe_ts_library(
         "@npm//@types/node",
         "@npm//@types/offscreencanvas",
         "@npm//@types/google-protobuf",
+        "@npm//@webgpu/types",
     ],
     testonly = testonly,
     declaration = True,
mediapipe/tasks/cc/genai/inference/calculators/BUILD (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = [
    "//mediapipe/tasks:__subpackages__",
])

licenses(["notice"])

mediapipe_proto_library(
    name = "detokenizer_calculator_proto",
    srcs = ["detokenizer_calculator.proto"],
)

mediapipe_proto_library(
    name = "tokenizer_calculator_proto",
    srcs = ["tokenizer_calculator.proto"],
)

mediapipe_proto_library(
    name = "llm_gpu_calculator_proto",
    srcs = ["llm_gpu_calculator.proto"],
    deps = [
        "//mediapipe/tasks/cc/genai/inference/proto:llm_file_metadata_proto",
        "//mediapipe/tasks/cc/genai/inference/proto:llm_params_proto",
    ],
)
mediapipe/tasks/cc/genai/inference/calculators/detokenizer_calculator.proto (new file, 34 lines)

@@ -0,0 +1,34 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package odml.infra.proto;

option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "DetokenizerCalculatorOptionsProto";

message DetokenizerCalculatorOptions {
  // The path to the SentencePiece model file.
  string spm_model_file = 1;

  // A set of tokens to stop the decoding process whenever they appear in the
  // result.
  repeated string stop_tokens = 4;

  // How many sets of input IDs need to be detokenized.
  int32 num_output_heads = 5;

  reserved 2, 3;
}
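These options are wired up from TypeScript later in this commit (see `buildLlmInferenceGraph` in llm_inference.ts). A minimal sketch of that pattern, using the jspb bindings generated from this proto and the import paths used by llm_inference.ts below:

```
import {Any} from 'google-protobuf/google/protobuf/any_pb';

import {DetokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/detokenizer_calculator_pb';

// Stop decoding at the <eos> token and detokenize a single output head.
const detokenizerOptions = new DetokenizerCalculatorOptions();
detokenizerOptions.setNumOutputHeads(1);
detokenizerOptions.addStopTokens('<eos>');

// Calculator node options travel as a type-erased Any proto whose type URL
// names the message declared above.
const packedOptions = new Any();
packedOptions.setTypeUrl(
    'type.googleapis.com/odml.infra.proto.DetokenizerCalculatorOptions');
packedOptions.setValue(detokenizerOptions.serializeBinary());
```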
mediapipe/tasks/cc/genai/inference/calculators/llm_gpu_calculator.proto (new file, 72 lines)

@@ -0,0 +1,72 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package odml.infra.proto;

import "mediapipe/tasks/cc/genai/inference/proto/llm_file_metadata.proto";
import "mediapipe/tasks/cc/genai/inference/proto/llm_params.proto";

option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "LlmGpuCalculatorOptionsProto";

message LlmGpuCalculatorOptions {
  string weight_path = 1;

  // Use LlmParameters instead.
  reserved 2, 3, 4, 5, 6, 7, 8, 9;

  message GpuModelInfo {
    // If set to true, use float16 precision in computation.
    bool allow_precision_loss = 1;
    bool enable_fast_tuning = 2;
    bool enable_winograd_opt = 3;
    bool use_low_power = 4;
    bool prefer_texture_weights = 5;
    bool enable_host_mapped_pointer = 6;
  }
  GpuModelInfo gpu_model_info = 10;

  // Use LlmParameters instead.
  reserved 11;

  int32 num_decode_tokens = 12;

  // Use lora_ranks instead.
  reserved 13;

  int32 sequence_batch_size = 14;

  string lora_path = 19;

  odml.infra.proto.LlmParameters llm_parameters = 20;

  // Each output head will generate tokens independently of the others.
  int32 num_output_heads = 22;

  // The number of candidate tokens to sample from the softmax output in
  // top-k sampling.
  int32 topk = 23;

  // The softmax temperature. For any value less than 1/1024 (the difference
  // between 1.0 and the next representable value for half-precision floats),
  // the sampling op collapses to an ArgMax.
  float temperature = 24;

  // Random seed for sampling tokens.
  optional uint32 random_seed = 25;

  reserved 26, 27;
}
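A back-of-the-envelope reading of the 1/1024 bound in the `temperature` comment (the source only states the threshold): half precision has 10 explicit mantissa bits, so the spacing between 1.0 and the next representable fp16 value is

```
\mathrm{ulp}_{\mathrm{fp16}}(1.0) = 2^{-10} = \tfrac{1}{1024} \approx 9.77 \times 10^{-4}
```

A temperature smaller than that gap is effectively zero at fp16 resolution; dividing logits by it inflates their differences past what softmax can resolve, so sampling degenerates to an argmax over the logits.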
mediapipe/tasks/cc/genai/inference/calculators/tokenizer_calculator.proto (new file, 53 lines)

@@ -0,0 +1,53 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package odml.infra.proto;

option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "TokenizerCalculatorOptionsProto";

message TokenizerCalculatorOptions {
  // The maximum number of tokens for the calculator's BERT model. Used
  // if the model's input tensors have static shape.
  int32 max_tokens = 1;

  message TfLiteModelFile {
    // (Optional) The path to the tflite model file, whose metadata (looked
    // up by the key below) contains the buffer to initialize the tokenizer.
    // If empty, the model should be provided through e.g. a side input.
    string model_file = 1;

    // The key in the tflite model metadata field whose corresponding data is
    // a SentencePiece model. If not provided, "spm_vocab_model" will be used.
    string spm_model_key_in_metadata = 2;
  }

  oneof model_file {
    // The path to the SentencePiece model file.
    string spm_model_file = 2;

    // Options to load the model from the tflite model's metadata.
    TfLiteModelFile tflite_model_file = 4;
  }

  // The start token id to be prepended to the input prompt token ids. Note
  // that this should match the setting used while training the model.
  int32 start_token_id = 3;

  // Whether to convert string bytes to unicode before tokenization. This
  // mapping is used in GPT-2 style tokenizers.
  bool bytes_to_unicode_mapping = 5;
}
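The `model_file` oneof means the tokenizer is initialized from exactly one source. A sketch of both choices with the generated jspb bindings (`setSpmModelFile` is the conventional jspb setter for `spm_model_file` and is assumed here; the path in the comment is hypothetical, and only the tflite route is exercised by the Web task below):

```
const tokenizerOptions = new TokenizerCalculatorOptions();
tokenizerOptions.setMaxTokens(512);
tokenizerOptions.setStartTokenId(2);

// Option A: a standalone SentencePiece model file (hypothetical path).
// tokenizerOptions.setSpmModelFile('/path/to/tokenizer.spm');

// Option B: read the SentencePiece model out of the tflite metadata, as the
// Web task does. Setting one member of the oneof clears the other.
const tfliteModelFile = new TokenizerCalculatorOptions.TfLiteModelFile();
tfliteModelFile.setSpmModelKeyInMetadata('spm_vocab_model');
tokenizerOptions.setTfliteModelFile(tfliteModelFile);
```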
mediapipe/tasks/cc/genai/inference/proto/BUILD

@@ -31,3 +31,9 @@ mediapipe_proto_library(
     srcs = ["llm_params.proto"],
     deps = [":transformer_params_proto"],
 )
+
+mediapipe_proto_library(
+    name = "llm_file_metadata_proto",
+    srcs = ["llm_file_metadata.proto"],
+    deps = [":llm_params_proto"],
+)
mediapipe/tasks/cc/genai/inference/proto/llm_file_metadata.proto (new file, 54 lines)

@@ -0,0 +1,54 @@
// Copyright 2024 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package odml.infra.proto;

import "mediapipe/tasks/cc/genai/inference/proto/llm_params.proto";

option java_package = "com.google.odml.infra.proto";
option java_outer_classname = "LlmFileMetadataProto";

// Metadata for an LLM file loaded by `GpuOptimizedTensorLoader`.
message LlmFileMetadata {
  message TensorInfo {
    // The name of the tensor, e.g. "params.lm.softmax.logits_ffn.linear.w".
    string name = 1;

    // The offset of the tensor data in the file, in bytes.
    uint64 offset = 2;

    // The size of the tensor data in bytes.
    uint64 size = 3;

    // In some cases `size` is insufficient to determine the data type,
    // e.g. 4-bit data could be INT4 or UINT4. In those cases `data_type`
    // gives more precise information.
    enum DataType {
      UNSPECIFIED = 0;
      FLOAT32 = 1;
      INT8 = 2;
      INT4 = 3;
      UINT4 = 4;
    }
    DataType data_type = 4;
  }
  repeated TensorInfo tensors = 1;

  odml.infra.proto.LlmParameters model_params = 2;

  // The LoRA rank if this is a set of LoRA weights.
  int32 lora_rank = 3;
}
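A sketch of how a consumer of this metadata might use `data_type` to recover element counts from raw byte sizes, assuming the generated jspb bindings follow standard naming (`getTensorsList`, `getSize`, `getDataType`; these bindings and the import path are not shown elsewhere in this commit):

```
import {LlmFileMetadata} from '../../../../tasks/cc/genai/inference/proto/llm_file_metadata_pb';  // assumed path

function elementCount(tensor: LlmFileMetadata.TensorInfo): number {
  const {DataType} = LlmFileMetadata.TensorInfo;
  const bytes = tensor.getSize();
  switch (tensor.getDataType()) {
    case DataType.FLOAT32:
      return bytes / 4;
    case DataType.INT8:
      return bytes;
    // `size` alone cannot tell these two apart, hence the enum.
    case DataType.INT4:
    case DataType.UINT4:
      return bytes * 2;  // two 4-bit values per byte
    default:
      throw new Error(`Unknown data type for tensor ${tensor.getName()}`);
  }
}
```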
mediapipe/tasks/web/genai/BUILD (new file, 119 lines)
@@ -0,0 +1,119 @@
# This contains the MediaPipe GenAI Tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
load(
    "//mediapipe/framework/tool:mediapipe_files.bzl",
    "mediapipe_files",
)

package(default_visibility = ["//mediapipe/tasks:internal"])

mediapipe_files(srcs = [
    "wasm/genai_wasm_internal.js",
    "wasm/genai_wasm_internal.wasm",
    "wasm/genai_wasm_nosimd_internal.js",
    "wasm/genai_wasm_nosimd_internal.wasm",
])

GENAI_LIBS = [
    "//mediapipe/tasks/web/core:fileset_resolver",
    "//mediapipe/tasks/web/genai/llm_inference",
]

mediapipe_ts_library(
    name = "genai_lib",
    srcs = ["index.ts"],
    visibility = ["//visibility:public"],
    deps = GENAI_LIBS,
)

mediapipe_ts_library(
    name = "genai_types",
    srcs = ["types.ts"],
    visibility = ["//visibility:public"],
    deps = GENAI_LIBS,
)

rollup_bundle(
    name = "genai_bundle_mjs",
    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
    entry_point = "index.ts",
    format = "esm",
    output_dir = False,
    sourcemap = "true",
    deps = [
        ":genai_lib",
        "@npm//@rollup/plugin-commonjs",
        "@npm//@rollup/plugin-node-resolve",
        "@npm//@rollup/plugin-terser",
        "@npm//google-protobuf",
    ],
)

rollup_bundle(
    name = "genai_bundle_cjs",
    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
    entry_point = "index.ts",
    format = "cjs",
    output_dir = False,
    sourcemap = "true",
    deps = [
        ":genai_lib",
        "@npm//@rollup/plugin-commonjs",
        "@npm//@rollup/plugin-node-resolve",
        "@npm//@rollup/plugin-terser",
        "@npm//google-protobuf",
    ],
)

genrule(
    name = "genai_sources",
    srcs = [
        ":genai_bundle_cjs",
        ":genai_bundle_mjs",
    ],
    outs = [
        "genai_bundle.cjs",
        "genai_bundle.cjs.map",
        "genai_bundle.mjs",
        "genai_bundle.mjs.map",
    ],
    cmd = (
        "for FILE in $(SRCS); do " +
        "  OUT_FILE=$(GENDIR)/mediapipe/tasks/web/genai/$$(" +
        "    basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
        "  ); " +
        "  echo $$FILE ; echo $$OUT_FILE ; " +
        "  cp $$FILE $$OUT_FILE ; " +
        "done;"
    ),
)

genrule(
    name = "package_json",
    srcs = ["//mediapipe/tasks/web:package.json"],
    outs = ["package.json"],
    cmd = "cp $< $@",
)

pkg_npm(
    name = "genai_pkg",
    package_name = "@mediapipe/tasks-__NAME__",
    srcs = ["README.md"],
    substitutions = {
        "__NAME__": "genai",
        "__DESCRIPTION__": "MediaPipe GenAI Tasks",
        "__TYPES__": "genai.d.ts",
    },
    tgz = "genai.tgz",
    deps = [
        "wasm/genai_wasm_internal.js",
        "wasm/genai_wasm_internal.wasm",
        "wasm/genai_wasm_nosimd_internal.js",
        "wasm/genai_wasm_nosimd_internal.wasm",
        ":genai_sources",
        ":package_json",
    ],
)
mediapipe/tasks/web/genai/README.md (new file, 17 lines)
@@ -0,0 +1,17 @@
# MediaPipe Tasks GenAI Package

This package contains the GenAI tasks for MediaPipe.

## LLM Inference

The MediaPipe LLM Inference task generates a text response from input text.

```
const genai = await FilesetResolver.forGenAiTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm"
);
const llmInference = await LlmInference.createFromModelPath(genai, MODEL_URL);
const response = llmInference.generateResponse(inputText);
```

<!-- TODO: Complete README for MediaPipe GenAI Task. -->
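The README snippet shows the one-shot call; the task in this commit also supports streaming via `setChunkGenerationCallback` combined with `generateResponseAsync`. A sketch reusing `genai`, `MODEL_URL`, and `inputText` from above, built through `createFromOptions` as defined in llm_inference.ts below:

```
const llmInference = await LlmInference.createFromOptions(genai, {
  baseOptions: {modelAssetPath: MODEL_URL},
});

// Stream decoded chunks as they are produced...
llmInference.setChunkGenerationCallback((chunk) => {
  console.log(chunk);
});

// ...and still collect the full result (one string per output head).
const responses = await llmInference.generateResponseAsync(inputText);
```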
mediapipe/tasks/web/genai/index.ts (new file, 25 lines)
@@ -0,0 +1,25 @@
/**
 * Copyright 2024 The MediaPipe Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {LlmInference as LlmInferenceImpl} from '../../../tasks/web/genai/llm_inference/llm_inference';

// Declare the variables locally so that Rollup in OSS includes them explicitly
// as exports.
const FilesetResolver = FilesetResolverImpl;
const LlmInference = LlmInferenceImpl;

export {FilesetResolver, LlmInference};
mediapipe/tasks/web/genai/llm_inference/BUILD (new file, 43 lines)
@@ -0,0 +1,43 @@
# This contains the MediaPipe LLM Inference Task.
#
# This task takes text input and performs text generation.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "llm_inference",
    srcs = ["llm_inference.ts"],
    visibility = ["//visibility:public"],
    deps = [
        ":llm_inference_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
        "//mediapipe/tasks/cc/genai/inference/calculators:detokenizer_calculator_jspb_proto",
        "//mediapipe/tasks/cc/genai/inference/calculators:llm_gpu_calculator_jspb_proto",
        "//mediapipe/tasks/cc/genai/inference/calculators:tokenizer_calculator_jspb_proto",
        "//mediapipe/tasks/cc/genai/inference/proto:llm_params_jspb_proto",
        "//mediapipe/tasks/cc/genai/inference/proto:transformer_params_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/tasks/web/genai/llm_inference/proto:llm_inference_graph_options_jspb_proto",
        "//mediapipe/web/graph_runner:graph_runner_ts",
        "//mediapipe/web/graph_runner/internal:graph_runner_wasm_file_reference_ts",
        "//mediapipe/web/graph_runner/internal:graph_runner_webgpu_ts",
    ],
)

mediapipe_ts_declaration(
    name = "llm_inference_types",
    srcs = [
        "llm_inference_options.d.ts",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/tasks/web/core",
    ],
)
mediapipe/tasks/web/genai/llm_inference/llm_inference.ts (new file, 435 lines)
@@ -0,0 +1,435 @@
/**
 * Copyright 2024 The MediaPipe Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {Any} from 'google-protobuf/google/protobuf/any_pb';

import {CalculatorGraphConfig, InputStreamInfo} from '../../../../framework/calculator_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {LlmInferenceGraphOptions as LlmInferenceGraphOptionsProto} from '../../../../tasks/web/genai/llm_inference/proto/llm_inference_graph_options_pb';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
import {SupportWasmFileReference, WasmFileReference} from '../../../../web/graph_runner/internal/graph_runner_wasm_file_reference';
import {SupportWebGpu} from '../../../../web/graph_runner/internal/graph_runner_webgpu';
import {DetokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/detokenizer_calculator_pb';
import {LlmGpuCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/llm_gpu_calculator_pb';
import {TokenizerCalculatorOptions} from '../../../../tasks/cc/genai/inference/calculators/tokenizer_calculator_pb';
import {LlmParameters} from '../../../../tasks/cc/genai/inference/proto/llm_params_pb';
import {TransformerParameters} from '../../../../tasks/cc/genai/inference/proto/transformer_params_pb';

import {LlmInferenceOptions} from './llm_inference_options';

export * from './llm_inference_options';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

// TODO: b/327515383 - Use the ReturnType pattern to apply extensions to the
// LLM Web API.
// tslint:disable-next-line:enforce-name-casing
const WasmFileReferenceWebGpuGraphRunnerType =
    SupportWebGpu(SupportWasmFileReference(CachedGraphRunner));
class WasmFileReferenceWebGpuGraphRunner extends
    WasmFileReferenceWebGpuGraphRunnerType {}

/**
 * A callback that receives the result from the LLM Inference.
 */
export type LlmInferenceCallback = (result: string[]) => void;

const INPUT_STREAM = 'text_in';
const OUTPUT_STREAM = 'text_out';
const OUTPUT_END_STREAM = 'text_end';

/**
 * Performs LLM Inference on text.
 */
export class LlmInference extends TaskRunner {
  private static readonly TOKEN_SPLITTER =
      '▁';  // Note this is NOT an underscore: ▁ (U+2581)
  private static readonly NEW_LINE = '<0x0A>';
  private static readonly EOD = '\\[eod\\]';
  private static readonly LLM_MODEL_NAME = 'llm.tflite';
  private static readonly TOKENIZER_MODE_IN_TFLITE_KEY = 'spm_vocab_model';

  private readonly generationResult: string[] = [];
  private readonly options = new LlmInferenceGraphOptionsProto();
  private isProcessing = false;
  private resolveGeneration?: (result: string[]) => void;
  private userCallback: LlmInferenceCallback = (result: string[]) => {};
  private chunkGenerationCallback = (result: string) => {};
  private wasmFileReference?: WasmFileReference;

  /**
   * Initializes the Wasm runtime and creates a new LLM inference task from
   * the provided options.
   * @export
   * @param wasmFileset A configuration object that provides the location of
   *     the Wasm binary and its loader.
   * @param llmInferenceOptions The options for the LLM Inference task. Note
   *     that either a path to the TFLite model or the model itself needs to
   *     be provided (via `baseOptions`).
   */
  static async createFromOptions(
      wasmFileset: WasmFileset,
      llmInferenceOptions: LlmInferenceOptions): Promise<LlmInference> {
    // TODO: b/324482487 - Support customizing config for Web task of LLM
    // Inference.
    const optionsWithGpuDevice = llmInferenceOptions;
    if (!optionsWithGpuDevice.baseOptions?.gpuOptions?.device) {
      const deviceDescriptor: GPUDeviceDescriptor = {
        requiredFeatures: ['shader-f16'],
        requiredLimits: {
          'maxStorageBufferBindingSize': 524550144,
          'maxBufferSize': 524550144,
        },
      };
      const webgpuDevice =
          await WasmFileReferenceWebGpuGraphRunner.requestWebGpuDevice(
              deviceDescriptor);
      optionsWithGpuDevice.baseOptions = llmInferenceOptions.baseOptions ?? {};
      optionsWithGpuDevice.baseOptions.gpuOptions =
          llmInferenceOptions?.baseOptions?.gpuOptions ?? {};
      optionsWithGpuDevice.baseOptions.gpuOptions.device = webgpuDevice;
    }

    return TaskRunner.createInstance(
        LlmInference, /* canvas= */ null, wasmFileset, optionsWithGpuDevice);
  }

  /** @hideconstructor */
  constructor(
      wasmModule: WasmModule,
      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
    super(new WasmFileReferenceWebGpuGraphRunner(wasmModule, glCanvas));
    this.options.setBaseOptions(new BaseOptionsProto());
  }

  // TODO: b/325936012 - Move setChunkGenerationCallback to the LLM Inference
  // task options.
  /**
   * Registers a callback that is invoked with a string of newly generated
   * tokens whenever the LLM Inference task produces them.
   *
   * @param callback The callback that is invoked with the newly generated
   *     tokens.
   */
  setChunkGenerationCallback(callback: (result: string) => void) {
    this.chunkGenerationCallback = callback;
  }

  /**
   * Sets new options for the LLM inference task.
   *
   * Calling `setOptions()` with a subset of options only affects those
   * options. You can reset an option back to its default value by explicitly
   * setting it to `undefined`.
   *
   * @export
   * @param options The options for the LLM inference task.
   */
  override setOptions(options: LlmInferenceOptions): Promise<void> {
    // TODO: b/324482487 - Support customizing config for Web task of LLM
    // Inference.
    if (this.wasmFileReference) {
      this.wasmFileReference.free();
    }
    if (options.baseOptions?.gpuOptions?.device) {
      (this.graphRunner as unknown as WasmFileReferenceWebGpuGraphRunner)
          .initializeForWebGpu(options.baseOptions.gpuOptions.device);
    }
    if (options?.baseOptions?.modelAssetPath) {
      return WasmFileReference
          .loadFromUrl(
              this.graphRunner.wasmModule, options.baseOptions.modelAssetPath)
          .then((wasmFileReference: WasmFileReference) => {
            this.wasmFileReference = wasmFileReference;
            this.refreshGraph();
            this.onGraphRefreshed();
          });
    } else if (options?.baseOptions?.modelAssetBuffer) {
      this.wasmFileReference = WasmFileReference.loadFromArray(
          this.graphRunner.wasmModule, options.baseOptions.modelAssetBuffer);
      this.refreshGraph();
      this.onGraphRefreshed();
    }
    return Promise.resolve();
  }

  protected override get baseOptions(): BaseOptionsProto {
    return this.options.getBaseOptions()!;
  }

  protected override set baseOptions(proto: BaseOptionsProto) {
    this.options.setBaseOptions(proto);
  }

  /**
   * Decodes the response from the LLM engine and returns a human-readable
   * string.
   */
  static decodeResponse(responses: string[], stripLeadingWhitespace: boolean):
      string {
    if (responses == null || responses.length === 0) {
      // Technically, this is an error. We should always get at least one
      // response.
      return '';
    }

    let response = responses[0];  // We only use the first response.
    response = response.replaceAll(LlmInference.TOKEN_SPLITTER, ' ');
    response = response.replaceAll(
        LlmInference.NEW_LINE, '\n');  // Replace <0x0A> token with newline.

    if (stripLeadingWhitespace) {
      response = response.trimStart();
    }

    return response.split(LlmInference.EOD, 1)[0];
  }

  /**
   * Performs LLM inference on the provided text and invokes the callback
   * with the generated text results once generation completes.
   *
   * @export
   * @param text The text to process.
   * @param callback The callback that is invoked with the result.
   */
  generateResponse(text: string, callback: LlmInferenceCallback): void {
    if (this.isProcessing) {
      throw new Error('Previous invocation is still processing.');
    }
    this.generationResult.length = 0;
    this.userCallback = callback;
    this.isProcessing = true;
    this.graphRunner.addStringToStream(
        text, INPUT_STREAM, this.getSynctheticTimestamp());
    this.finishProcessing();
  }

  /**
   * Performs LLM inference on the provided text and returns a promise that
   * resolves with the response.
   *
   * @export
   * @param text The text to process.
   * @return The generated text results.
   */
  generateResponseAsync(text: string): Promise<string[]> {
    if (this.isProcessing) {
      throw new Error('Previous invocation is still processing.');
    }
    this.generationResult.length = 0;
    this.isProcessing = true;
    this.graphRunner.addStringToStream(
        text, INPUT_STREAM, this.getSynctheticTimestamp());
    this.finishProcessing();
    return new Promise<string[]>((resolve) => {
      this.resolveGeneration = resolve;
    });
  }

  // TODO: b/324919242 - Add sync API for BYOM Web API when Chrome JSPI is
  // available.

  /** Updates the MediaPipe graph configuration. */
  protected override refreshGraph(): void {
    const graphConfig = this.buildLlmInferenceGraph();

    this.graphRunner.attachStringVectorListener(
        OUTPUT_STREAM, (stringVector, timestamp) => {
          const stripLeadingWhitespace = this.generationResult.length === 0;
          const decodedText =
              LlmInference.decodeResponse(stringVector, stripLeadingWhitespace);
          this.generationResult.push(decodedText);
          this.chunkGenerationCallback(decodedText);
          this.setLatestOutputTimestamp(timestamp);
        });
    this.graphRunner.attachEmptyPacketListener(OUTPUT_STREAM, timestamp => {
      this.setLatestOutputTimestamp(timestamp);
    });

    this.graphRunner.attachBoolListener(
        OUTPUT_END_STREAM, (bool, timestamp) => {
          this.isProcessing = false;
          if (this.resolveGeneration) {
            this.resolveGeneration(this.generationResult);
          }
          this.userCallback(this.generationResult);
          this.setLatestOutputTimestamp(timestamp);
        });
    this.graphRunner.attachEmptyPacketListener(OUTPUT_END_STREAM, timestamp => {
      this.setLatestOutputTimestamp(timestamp);
    });

    if (this.wasmFileReference) {
      (this.graphRunner as unknown as WasmFileReferenceWebGpuGraphRunner)
          .addWasmFileReferenceToInputSidePacket(
              this.wasmFileReference,
              'model_file_reference',
          );
    }

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }

  private buildLlmInferenceGraph(): CalculatorGraphConfig {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addInputSidePacket('model_file_reference');
    graphConfig.addOutputStream(OUTPUT_STREAM);
    graphConfig.addOutputStream(OUTPUT_END_STREAM);

    // TokenizerInputBuilder node
    const tokenizerInputBuildNode = new CalculatorGraphConfig.Node();
    tokenizerInputBuildNode.setCalculator('TokenizerInputBuildCalculator');
    tokenizerInputBuildNode.addInputStream(INPUT_STREAM);
    tokenizerInputBuildNode.addOutputStream('prompt');
    graphConfig.addNode(tokenizerInputBuildNode);

    // TFLite model node
    const tfliteModelNode = new CalculatorGraphConfig.Node();
    tfliteModelNode.setCalculator('TfLiteModelCalculator');
    tfliteModelNode.addInputSidePacket('MODEL_SPAN:model_file_reference');
    tfliteModelNode.addOutputSidePacket('SHARED_MODEL:__side_packet_0');
    graphConfig.addNode(tfliteModelNode);

    // Tokenizer node
    const tokenizerOptionsProto = new Any();
    tokenizerOptionsProto.setTypeUrl(
        'type.googleapis.com/odml.infra.proto.TokenizerCalculatorOptions');
    const tokenizerOptions = new TokenizerCalculatorOptions();
    tokenizerOptions.setMaxTokens(512);

    const modelFile = new TokenizerCalculatorOptions.TfLiteModelFile();
    modelFile.setSpmModelKeyInMetadata(
        LlmInference.TOKENIZER_MODE_IN_TFLITE_KEY);
    tokenizerOptions.setTfliteModelFile(modelFile);

    tokenizerOptions.setStartTokenId(2);
    tokenizerOptionsProto.setValue(tokenizerOptions.serializeBinary());
    const tokenizerNode = new CalculatorGraphConfig.Node();
    tokenizerNode.setCalculator('TokenizerCalculator');
    tokenizerNode.addNodeOptions(tokenizerOptionsProto);
    tokenizerNode.addInputStream('PROMPT:prompt');
    tokenizerNode.addOutputSidePacket('PROCESSOR:__input_side_1');
    tokenizerNode.addInputSidePacket('TFLITE_MODEL:__side_packet_0');
    tokenizerNode.addOutputStream('IDS:__stream_0');
    graphConfig.addNode(tokenizerNode);

    // LlmGpu node
    const llmGpuOptionsProto = new Any();
    llmGpuOptionsProto.setTypeUrl(
        'type.googleapis.com/odml.infra.proto.LlmGpuCalculatorOptions');
    const llmGpuOptions = new LlmGpuCalculatorOptions();

    llmGpuOptions.setNumDecodeTokens(3);
    llmGpuOptions.setWeightPath(LlmInference.LLM_MODEL_NAME);
    llmGpuOptions.setSequenceBatchSize(0);
    llmGpuOptions.setNumOutputHeads(1);
    llmGpuOptions.setTopk(1);
    llmGpuOptions.setTemperature(1.0);
    const gpuModelInfo = new LlmGpuCalculatorOptions.GpuModelInfo();
    gpuModelInfo.setAllowPrecisionLoss(true);
    gpuModelInfo.setEnableFastTuning(true);
    gpuModelInfo.setPreferTextureWeights(true);
    llmGpuOptions.setGpuModelInfo(gpuModelInfo);

    const llmParams = new LlmParameters();
    const transformerParams = new TransformerParameters();
    transformerParams.setBatchSize(1);
    transformerParams.setMaxSeqLength(512);
    llmParams.setTransformerParameters(transformerParams);
    llmGpuOptions.setLlmParameters(llmParams);

    llmGpuOptionsProto.setValue(llmGpuOptions.serializeBinary());
    const llmGpuNode = new CalculatorGraphConfig.Node();
    llmGpuNode.setCalculator('LlmGpuCalculator');
    llmGpuNode.addNodeOptions(llmGpuOptionsProto);
    llmGpuNode.addInputStream('INPUT_PROMPT_IDS:__stream_0');
    llmGpuNode.addInputStream('FINISH:finish');
    llmGpuNode.addInputSidePacket('SHARED_MODEL:__side_packet_0');
    llmGpuNode.addOutputStream('DECODED_IDS:__stream_3');
    llmGpuNode.addOutputStream('OUTPUT_END:__stream_4');
    const backEdgeInputStreamInfo = new InputStreamInfo();
    backEdgeInputStreamInfo.setTagIndex('FINISH');
    backEdgeInputStreamInfo.setBackEdge(true);
    llmGpuNode.addInputStreamInfo(backEdgeInputStreamInfo);
    graphConfig.addNode(llmGpuNode);

    const isPacketPresentNode = new CalculatorGraphConfig.Node();
    isPacketPresentNode.setCalculator('IsPacketPresentCalculator');
    isPacketPresentNode.addInputStream('__stream_4');
    isPacketPresentNode.addOutputStream(OUTPUT_END_STREAM);
    graphConfig.addNode(isPacketPresentNode);

    // Detokenizer node
    const detokenizerOptionsProto = new Any();
    detokenizerOptionsProto.setTypeUrl(
        'type.googleapis.com/odml.infra.proto.DetokenizerCalculatorOptions');
    const detokenizerOptions = new DetokenizerCalculatorOptions();
    detokenizerOptions.setNumOutputHeads(1);
    // No need to set the spm model; reuse TokenizerCalculator's side input.
    detokenizerOptions.addStopTokens('<eos>');
    detokenizerOptionsProto.setValue(detokenizerOptions.serializeBinary());
    const detokenizerNode = new CalculatorGraphConfig.Node();
    detokenizerNode.setCalculator('DetokenizerCalculator');
    detokenizerNode.addNodeOptions(detokenizerOptionsProto);
    detokenizerNode.addInputStream('IDS:__stream_3');
    detokenizerNode.addInputSidePacket('PROCESSOR:__input_side_1');
    detokenizerNode.addOutputStream('FINISH:finish');
    detokenizerNode.addOutputStream('WORDS:' + OUTPUT_STREAM);
    graphConfig.addNode(detokenizerNode);
    return graphConfig;
  }

  override close() {
    // TODO: b/327307061 - Release the tflite file in the Wasm heap at the
    // earliest point.
    if (this.wasmFileReference) {
      this.wasmFileReference.free();
    }
    super.close();
  }
}
mediapipe/tasks/web/genai/llm_inference/llm_inference_options.d.ts (new file, vendored, 40 lines)
@@ -0,0 +1,40 @@
/**
 * Copyright 2024 The MediaPipe Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {BaseOptions, TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';

/**
 * Options to configure the WebGPU device for the LLM Inference task.
 */
export declare interface WebGpuOptions {
  device?: GPUDevice;
  // TODO: b/327685206 - Fill adapter info for the LLM Web task.
  adapterInfo?: GPUAdapterInfo;
}

/**
 * Options to configure the model loading and processing for the LLM
 * Inference task.
 */
export declare interface LlmBaseOptions extends BaseOptions {
  gpuOptions?: WebGpuOptions;
}

// TODO: b/324482487 - Support customizing config for Web task of LLM Inference.
/** Options to configure the MediaPipe LLM Inference Task. */
export declare interface LlmInferenceOptions extends TaskRunnerOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: LlmBaseOptions;
}
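When no device is supplied, `createFromOptions` (in llm_inference.ts above) requests one with `shader-f16` and raised buffer limits; callers can instead create the device themselves and hand it over through `gpuOptions`. A sketch using the standard WebGPU API, with `MODEL_URL` as a placeholder and type imports elided:

```
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
  throw new Error('WebGPU is not available.');
}
// Mirror the descriptor used by createFromOptions.
const device = await adapter.requestDevice({
  requiredFeatures: ['shader-f16'],
  requiredLimits: {
    maxStorageBufferBindingSize: 524550144,
    maxBufferSize: 524550144,
  },
});

const options: LlmInferenceOptions = {
  baseOptions: {
    modelAssetPath: MODEL_URL,  // placeholder
    gpuOptions: {device},
  },
};
```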
@@ -0,0 +1,30 @@
/**
 * Copyright 2024 The MediaPipe Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

import {LlmInference} from './llm_inference';

describe('LlmInference', () => {
  it('decoding LLM Inference output succeeds', () => {
    expect(LlmInference.decodeResponse([' strip leading'], true))
        .toBe('strip leading');
    expect(LlmInference.decodeResponse(['▁replace▁underscore▁'], false))
        .toBe(' replace underscore ');
    expect(LlmInference.decodeResponse(['▁replace▁newline<0x0A>'], false))
        .toBe(' replace newline\n');
  });
});
mediapipe/tasks/web/genai/llm_inference/proto/BUILD (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

mediapipe_proto_library(
    name = "llm_inference_graph_options_proto",
    srcs = ["llm_inference_graph_options.proto"],
    deps = [
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)
mediapipe/tasks/web/genai/llm_inference/proto/llm_inference_graph_options.proto (new file, 32 lines)

@@ -0,0 +1,32 @@
/* Copyright 2024 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package mediapipe.tasks.genai.llm_inference.proto;

import "mediapipe/tasks/cc/core/proto/base_options.proto";

option java_package = "com.google.mediapipe.tasks.genai.llminference.proto";
option java_outer_classname = "LlmInferenceGraphOptionsProto";

message LlmInferenceGraphOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the
  // model asset bundle file with metadata, accelerator options, etc.
  core.proto.BaseOptions base_options = 1;

  // TODO: b/324482487 - Support customizing config for Web task of LLM
  // Inference.
}
mediapipe/tasks/web/genai/types.ts (new file, 18 lines)
@@ -0,0 +1,18 @@
/**
 * Copyright 2024 The MediaPipe Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/genai/llm_inference/llm_inference';
package.json

@@ -13,6 +13,7 @@
     "@types/jasmine": "^4.3.1",
     "@types/node": "^18.11.11",
     "@types/offscreencanvas": "^2019.7.0",
+    "@webgpu/types": "^0.1.40",
     "google-protobuf": "^3.21.2",
     "jasmine": "^4.5.0",
     "jasmine-core": "^4.5.0",
tsconfig.json

@@ -2,7 +2,7 @@
   "compilerOptions": {
     "target": "es2017",
     "module": "commonjs",
-    "lib": ["ES2017", "dom"],
+    "lib": ["ES2021", "dom"],
     "declaration": true,
     "moduleResolution": "node",
     "esModuleInterop": true,
@@ -10,7 +10,7 @@
     "inlineSourceMap": true,
     "inlineSources": true,
     "strict": true,
-    "types": ["@types/offscreencanvas", "@types/jasmine", "node"],
+    "types": ["@types/offscreencanvas", "@types/jasmine", "@webgpu/types", "node"],
     "rootDirs": [
       ".",
       "./bazel-out/host/bin",