|
using System; |
|
using System.Collections.Generic; |
|
using System.Linq; |
|
using Unity.Sentis; |
|
using UnityEngine; |
|
|
|
/// <summary>
/// Zero-shot text classification with a DeBERTa-v3 NLI model running on Unity Sentis.
/// Each candidate class is turned into a hypothesis ("This example is about {class}"),
/// the prompt/hypothesis pairs are tokenized into one padded batch, run through the
/// model, and the entailment probability is logged as that class's score.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    public ModelAsset model;
    public TextAsset vocabulary;        // newline-separated SentencePiece-style vocabulary
    public bool multipleTrueClasses;    // true => classes are scored independently of each other
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    IWorker engine;

    // Vocabulary term -> vocabulary line index. A dictionary gives O(1) lookups where
    // the previous Array.IndexOf scan was O(vocab size) per sub-word candidate.
    Dictionary<string, int> vocabularyIndex;

    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // The model's token ids are shifted by this amount relative to vocabulary line numbers.
    const int vocabToTokenOffset = 260;

    void Start()
    {
        // Guard the inspector-assigned assets so a misconfigured scene fails with a
        // clear message instead of a NullReferenceException.
        if (model == null || vocabulary == null)
        {
            Debug.LogError("Both a model asset and a vocabulary asset must be assigned");
            return;
        }
        if (classes == null || classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }

        // Build the vocabulary lookup once. TryAdd keeps the FIRST occurrence of a
        // duplicate line, matching the semantics of the Array.IndexOf scan it replaces.
        string[] vocabularyTokens = vocabulary.text.Replace("\r", "").Split('\n');
        vocabularyIndex = new Dictionary<string, int>(vocabularyTokens.Length);
        for (int i = 0; i < vocabularyTokens.Length; i++)
        {
            vocabularyIndex.TryAdd(vocabularyTokens[i], i);
        }

        Model baseModel = ModelLoader.Load(model);

        // Append a scoring head to the base model so the softmax runs on the GPU.
        Model modelWithScoring = Functional.Compile(
            input =>
            {
                FunctionalTensor logits = baseModel.Forward(input)[0];

                if (multipleTrueClasses || classes.Length == 1)
                {
                    // Classes are independent: softmax each hypothesis's
                    // entailment/contradiction logits on the last axis.
                    logits = Functional.Softmax(logits);
                }
                else
                {
                    // Exactly one class is true: softmax the entailment logits
                    // across the batch (hypothesis) axis instead.
                    logits = Functional.Softmax(logits, 0);
                }

                // Keep only the entailment column (index 0) for every hypothesis.
                return new[] { logits[.., 0] };
            },
            InputDef.FromModel(baseModel)
        );

        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);

        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Runs one tokenized batch through the worker and downloads the per-hypothesis
    /// entailment scores (one float per row of the batch).
    /// </summary>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);

        Dictionary<string, Tensor> inputs = new()
        {
            {"input_0", inputIds},
            {"input_1", attentionMask}
        };

        engine.Execute(inputs);

        // PeekOutput returns a tensor owned by the worker — do not dispose it here.
        TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
        scores.CompleteOperationsAndDownload();

        return scores.ToReadOnlyArray();
    }

    /// <summary>
    /// Builds one padded batch: every row is
    /// [start] prompt [sep] hypothesis [sep] [pad...], with attention masks that are
    /// 1 over real tokens and 0 over padding. All rows share the same length, so
    /// shorter hypotheses are right-padded to the longest one.
    /// </summary>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();

        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);

        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
            .Append(separatorToken)
            .Concat(hypothesis)
            .Append(separatorToken)
            .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // Mask layout mirrors the token layout: prompt + first separator, then
        // hypothesis + second separator, then zeros over the padding.
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
            .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
            .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;

        return batch;
    }

    /// <summary>
    /// Greedy longest-match tokenization against the vocabulary. Words are split on
    /// whitespace; the first sub-word of each word is looked up with the SentencePiece
    /// word-start marker "▁" prepended.
    /// </summary>
    List<int> Tokenize(string input)
    {
        // RemoveEmptyEntries guards against consecutive whitespace producing empty
        // "words", which the previous Split(null) silently allowed through.
        string[] words = input.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);

        List<int> ids = new();

        foreach (string word in words)
        {
            int start = 0;
            // The bound "end > start" guarantees every match consumes at least one
            // character. The previous "i >= 0" bound allowed a zero-length match
            // (infinite loop when the vocabulary contains "▁" or an empty line) and
            // a negative-length Substring (ArgumentOutOfRangeException) when a
            // remainder after a matched prefix had no vocabulary entry.
            for (int end = word.Length; end > start; end--)
            {
                string subWord = start == 0 ? "▁" + word.Substring(0, end) : word.Substring(start, end - start);
                if (vocabularyIndex.TryGetValue(subWord, out int index))
                {
                    ids.Add(index + vocabToTokenOffset);
                    if (end == word.Length) break;
                    start = end;
                    end = word.Length + 1; // restart the greedy search on the remainder (loop decrement follows)
                }
            }
            // NOTE(review): as in the original, an unmatchable remainder is silently
            // dropped rather than emitted as an unknown token — confirm this is intended.
        }

        return ids;
    }

    void OnDestroy() => engine?.Dispose();

    // Flat, row-major representation of one tokenized batch.
    struct Batch
    {
        public int BatchCount;      // number of rows (hypotheses)
        public int BatchLength;     // tokens per row, including padding
        public int[] BatchedTokens; // BatchCount * BatchLength token ids
        public int[] BatchedMasks;  // 1 for real tokens, 0 for padding
    }
}