From 25618b254188596b6c948e452469089d4f5a0257 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Tue, 28 Feb 2023 15:55:27 +0800 Subject: [PATCH 01/27] Add replacement class type --- src/Extensions/TorchCs/TorchUtil.cs | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 53711de..ebc76f3 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -57,10 +57,10 @@ public static string ReplaceCodes(string text) // replace 'self' to 'this' text = Regex.Replace(text, @"\bself\.", "this."); // replace field type - text = Regex.Replace(text, @"(object|void) (\w+ = ""\S+?""[,;)])", "string $2"); - text = Regex.Replace(text, @"(object|void) (\w+ = \d+[,;)])", "int $2"); - text = Regex.Replace(text, @"(object|void) (\w+ = \d+\.\d+[,;)])", "double $2"); - text = Regex.Replace(text, @"(object|void) (\w+ = (true|false)[,;)])", "bool $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = ""\S+?""[,;)])", "string $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+[,;)])", "int $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+\.\d+[,;)])", "double $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = (true|false)[,;)])", "bool $2"); // replace 'd_keys = d_keys or (d_model//n_heads)' to 'd_keys = d_keys ?? d_model / n_heads;' text = Regex.Replace(text, @"([a-zA-Z_0-9]+) = (\1 \|\| (.*?;))", "$1 = $1 ?? $3 //$2"); @@ -127,6 +127,7 @@ private static string replaceNamespace(string text) text = text.Replace("using optim = torch.optim;", "using optim = TorchSharp.torch.optim;"); text = text.Replace("using DataLoader = torch.utils.data.DataLoader;", "using DataLoader = TorchSharp.torch.utils.data.DataLoader;"); + text = text.Replace("using sys;", ""); text = text.Replace("using math;", ""); text = text.Replace("using os;", ""); text = text.Replace("using time;", ""); @@ -170,15 +171,21 @@ private static string replaceFieldType(string text) } var r = $@"this\.(\S+) = nn\.{methodName}\("; var ms = Regex.Matches(text, r); - if (ms.Count > 0) { - foreach (Match m in ms) { - var name = m.Groups[1].Value; - text = text.Replace($"public object {name};", $"public {fieldType} {name};"); - text = text.Replace($"public void {name};", $"public {fieldType} {name};"); - text = Regex.Replace(text, @$"\bthis\.{name}\(", $"this.{name}.forward("); - } + foreach (Match m in ms) { + var name = m.Groups[1].Value; + text = text.Replace($"public object {name};", $"public {fieldType} {name};"); + text = text.Replace($"public void {name};", $"public {fieldType} {name};"); + text = Regex.Replace(text, @$"\bthis\.{name}\(", $"this.{name}.forward("); } } + var ms2 = Regex.Matches(text, @"this\.(\S+) = new ([a-zA-Z_][a-zA-Z0-9_]+)\("); + foreach (Match m2 in ms2) { + var name = m2.Groups[1].Value; + var typeName = m2.Groups[2].Value; + text = text.Replace($"public object {name};", $"public {typeName} {name};"); + text = text.Replace($"public void {name};", $"public {typeName} {name};"); + } + text = replaceFieldType3(text); text = Regex.Replace(text, @"public (object|void) (\w+_len;)", "public int $2"); From 48504789cf5cc163fa60019d1b7897e9f8563b4b Mon Sep 17 00:00:00 2001 From: linzhijun Date: Tue, 28 Feb 2023 16:32:22 +0800 Subject: [PATCH 02/27] fix void to object public virtual void vali(object vali_data, object vali_loader, void criterion) => public virtual void vali(object vali_data, object vali_loader, object criterion) --- src/Extensions/TorchCs/TorchUtil.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index ebc76f3..b8cbc9d 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -61,6 +61,7 @@ public static string ReplaceCodes(string text) text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+[,;)])", "int $2"); text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+\.\d+[,;)])", "double $2"); text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = (true|false)[,;)])", "bool $2"); + text = Regex.Replace(text, @"\bvoid ([a-zA-Z_][a-zA-Z0-9_]*[ ,);])", "object $1"); // replace 'd_keys = d_keys or (d_model//n_heads)' to 'd_keys = d_keys ?? d_model / n_heads;' text = Regex.Replace(text, @"([a-zA-Z_0-9]+) = (\1 \|\| (.*?;))", "$1 = $1 ?? $3 //$2"); From 86dfbf68927e296a7bfd9e4590c20b05c656a4cb Mon Sep 17 00:00:00 2001 From: linzhijun Date: Tue, 28 Feb 2023 16:45:02 +0800 Subject: [PATCH 03/27] Convert python's [:,:,x[:,:]] syntax, used recurrence , exclude nesting --- src/Extensions/TorchCs/TorchUtil.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index b8cbc9d..b37d3e2 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -553,11 +553,12 @@ private static string replaceTensorList(string text) /// private static string replaceListSlice(string text) { - text = Regex.Replace(text, @"\[([^\[\]]*?)\]", new MatchEvaluator(m => { + text = Regex.Replace(text, @"\[(((?
\[)|(?<-BR>\])|[^\[\]])+)\]", new MatchEvaluator(m => { if (m.Groups[1].Value.Contains(":") == false) { return m.Value; } - var strs = m.Groups[1].Value.Split(','); + var ts = replaceListSlice(m.Groups[1].Value); // recurrence , exclude nesting + var strs = ts.Split(','); List list = new List(); foreach (var str in strs) { if (str.Trim() == "\":\"") { From 4c43cec17ec10de34ee235ee2c3266dd5ba248be Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 2 Mar 2023 10:00:07 +0800 Subject: [PATCH 04/27] Replace type by class area and static method area , More accurate --- src/Extensions/TorchCs/ClassInfo.cs | 513 +++++++++++++++++++++++ src/Extensions/TorchCs/TorchCs.csproj | 2 +- src/Extensions/TorchCs/TorchSharpInfo.cs | 96 +++++ src/Extensions/TorchCs/TorchUtil.cs | 53 ++- 4 files changed, 642 insertions(+), 22 deletions(-) create mode 100644 src/Extensions/TorchCs/ClassInfo.cs create mode 100644 src/Extensions/TorchCs/TorchSharpInfo.cs diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs new file mode 100644 index 0000000..a67d3f2 --- /dev/null +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -0,0 +1,513 @@ +using System.Text.RegularExpressions; +using static System.Net.Mime.MediaTypeNames; + +namespace TorchCs +{ + + + public class ClassFile + { + public string FileName { get; set; } + public string Code { get; set; } + public List ClassInfos { get; set; } + + public List StaticMethods { get; set; } + + public static List LoadFiles(string folder) + { + var files = new List(); + var files2 = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories); + foreach (var file in files2) { + var text = File.ReadAllText(file); + ClassFile classFile = new ClassFile(); + classFile.FileName = file; + classFile.Code = text; + classFile.ClassInfos = ClassInfo.AnalysisCode(text); + classFile.StaticMethods = ClassMethod.AnalysisCodeForStaticMethod(text); + } + return files; + } + } + + + public class ClassInfo + { + private const string classRegex = @"public class ([a-zA-Z_][a-zA-Z0-9_]*)([\s\S]*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string classRegex2 = @"public class {name}([\s\S]*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + + public ClassFile File { get; set; } + public string FullClassName { get; set; } + public string ClassName { get; set; } + public bool HasForwardMethod { get; set; } + public ClassConstructor Constructor { get; set; } + public List Fields { get; set; } + public List Methods { get; set; } + + public static List AnalysisCode(string code) + { + List classInfos = new List(); + var match = Regex.Match(code, @"namespace ([a-zA-Z_][a-zA-Z0-9._]*) "); + var match2 = Regex.Match(code, @"public static class ([a-zA-Z_][a-zA-Z0-9._]*) "); + var prefix = match.Groups[1].Value + "." + match2.Groups[1].Value; + + var ms = Regex.Matches(code, classRegex); + foreach (Match m in ms) { + ClassInfo classInfo = new ClassInfo(); + + classInfo.FullClassName = prefix + "." + m.Groups[1].Value; + classInfo.ClassName = m.Groups[1].Value; + var bodyCode = m.Groups[3].Value; + classInfo.Constructor = ClassConstructor.AnalysisCode(bodyCode, classInfo.ClassName); + classInfo.Fields = ClassField.AnalysisCode(bodyCode); + classInfo.Methods = ClassMethod.AnalysisCode(bodyCode); + classInfo.HasForwardMethod = classInfo.Methods.Any(q => q.MethodName == "forward"); + + classInfos.Add(classInfo); + } + var fclass = classInfos.Where(q => q.HasForwardMethod).Select(q => q.ClassName).ToList(); + foreach (var info in classInfos) { + foreach (var item in info.Fields) { + if (fclass.Contains(item.NewType ?? item.Type)) { + item.HasForwardMethod = true; + } + } + } + return classInfos; + } + public string AddNewField(string code) + { + if (Fields.Any(q => q.IsNewField)) { + code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { + var bodyCode = m.Groups[2].Value; + var baseClass = m.Groups[1].Value; + foreach (var field in Fields) { + bodyCode = field.AddNewField(bodyCode); + } + return $"public class {ClassName}{baseClass}{{{bodyCode}}}"; + })); + } + return code; + } + + public string ReplaceNewConstructor(string code, List classInfos) + { + foreach (var classInfo in classInfos) { + code = Regex.Replace(code, $@"\b{classInfo.ClassName}\(", $"new {classInfo.ClassName}("); + } + code = Regex.Replace(code, @"\bnew new ", "new "); + return code; + } + + public string ReplaceCodes(string code) + { + code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { + var bodyCode = m.Groups[2].Value; + var baseClass = m.Groups[1].Value; + foreach (var field in Fields) { + bodyCode = field.ReplaceCodes(bodyCode); + } + bodyCode = Constructor.ReplaceCodes(bodyCode); + foreach (var method in Methods) { + bodyCode = method.ReplaceCodes(bodyCode, Fields); + } + return $"public class {ClassName}{baseClass}{{{bodyCode}}}"; + })); + return code; + } + public override string ToString() + { + return $"class: {ClassName}"; + } + } + public class ClassConstructor + { + private const string constructorRegex = @"public {name}\(([^)]*?)\)(.*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + + public string ClassName { get; set; } + + public List Paramenters { get; set; } + public List Variables { get; set; } + + public static ClassConstructor AnalysisCode(string code, string className) + { + ClassConstructor classConstructor = new ClassConstructor(); + classConstructor.ClassName = className; + var reg = constructorRegex.Replace("{name}", className); + var m = Regex.Match(code, reg); + classConstructor.Paramenters = ClassMethodParamenter.AnalysisCode(m.Groups[1].Value, m.Groups[3].Value); + classConstructor.Variables = ClassMethodVariable.AnalysisCode(m.Groups[3].Value, classConstructor.Paramenters); + return classConstructor; + } + + + + public string ReplaceCodes(string code) + { + code = Regex.Replace(code, constructorRegex.Replace("{name}", ClassName), new MatchEvaluator(m => { + var ParamenterCode = m.Groups[1].Value; + foreach (var paramenter in Paramenters) { + ParamenterCode = paramenter.ReplaceCodes(ParamenterCode); + } + var BodyCode = m.Groups[3].Value; + foreach (var variable in Variables) { + BodyCode = variable.ReplaceCodes(BodyCode); + } + return $"public {ClassName}({ParamenterCode}){m.Groups[2].Value}{{{BodyCode}}}"; + })); + return code; + } + } + + public class ClassField + { + public string Type { get; set; } + public string NewType { get; set; } + public string FieldName { get; set; } + public bool IsNewField { get; set; } + public bool HasForwardMethod { get; set; } + + public static List AnalysisCode(string code) + { + List classFields = new List(); + HashSet fields = new HashSet(); + var ms = Regex.Matches(code, "public ([a-zA-Z_][a-zA-Z0-9_<>]*) ([a-zA-Z_][a-zA-Z0-9_]*);"); + foreach (Match match in ms) { + ClassField field = new ClassField(); + field.Type = match.Groups[1].Value; + field.FieldName = match.Groups[2].Value; + fields.Add(field.FieldName); + classFields.Add(field); + } + ms = Regex.Matches(code, @"\bthis\.([a-zA-Z_][a-zA-Z0-9_]*)[ \t\r\n,;)\[]"); + foreach (Match m in ms) { + if (fields.Add(m.Groups[1].Value)) { + ClassField field = new ClassField(); + field.Type = "object"; + field.FieldName = m.Groups[1].Value; + field.IsNewField = true; + classFields.Add(field); + } + } + + var nnMethods = TorchSharpInfo.Instance.nnMethods; + foreach (var method in nnMethods) { + var fieldType = method.ReturnType.Name; + var methodName = method.Name; + if (methodName == "ModuleDict" || methodName == "ModuleList") { continue; } + + var r = $@"this\.(\S+) = nn\.{methodName}\("; + var ms3 = Regex.Matches(code, r); + foreach (Match m in ms3) { + var name = m.Groups[1].Value; + var f = classFields.FirstOrDefault(q => q.FieldName == name); + if (f != null) { f.NewType = fieldType; f.HasForwardMethod = true; } + } + } + + var ms2 = Regex.Matches(code, @"this\.(\S+) = new ([a-zA-Z_][a-zA-Z0-9_]+)\("); + foreach (Match m2 in ms2) { + var name = m2.Groups[1].Value; + var typeName = m2.Groups[2].Value; + var f = classFields.FirstOrDefault(q => q.FieldName == name); + if (f != null) { f.NewType = typeName; } + } + + foreach (var field1 in classFields) { + if (field1.NewType != null) { continue; } + var name = field1.FieldName; + if (code.Contains($"if (this.{name})") || code.Contains($"if (!this.{name})")) { + field1.NewType = "bool"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=) (false|true)") || Regex.IsMatch(code, $@"(\(|&& |\|\| )!?this\.{name} (&&|\|\|)") || Regex.IsMatch(code, $@"(&&|\|\|) !?this\.{name}(\)| &&| \|\|)")) { + field1.NewType = "bool"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(code, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { + field1.NewType = "string"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { + field1.NewType = "double"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { + field1.NewType = "int"; + } else if (Regex.IsMatch(code, $@"this\.{name}\[[^\]]*?TensorIndex\.")) { + field1.NewType = "Tensor"; + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(channels|index|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + field1.NewType = "int"; + } else if (field1.Type == "object" && Regex.IsMatch(name, "^.*(_path|_name|_dir)$")) { + field1.NewType = "string"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { + // classMethodParamenter.NewType = "double"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@"[(, ]{name} [\+\-\*\/] ")) { + // classMethodParamenter.NewType = "double"; + } + } + return classFields; + } + public string AddNewField(string code) + { + if (IsNewField) { + return $"\r\n\t\t\tpublic {NewType ?? Type} {FieldName};{code}"; + } + return code; + } + + public string ReplaceCodes(string code) + { + if (NewType == null || NewType == Type) { return code; } + return code.Replace($"public {Type} {FieldName};", $"public {NewType} {FieldName};"); + } + + public override string ToString() + { + return $"field: {NewType ?? Type} {FieldName}"; + } + + } + + public class ClassMethod + { + private const string methodRegex = @"public (virtual) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex2 = @"public (virtual|static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) {name}\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + // public static int get_q_k(int input_size, int window_size, object stride, object device) + + public string MethodName { get; set; } + public string ReturnType { get; set; } + public string NewReturnType { get; set; } + public bool IsForwardMethod { get; set; } + + public List Paramenters { get; set; } = new List(); + public List Variables { get; set; } = new List(); + + public static List AnalysisCodeForStaticMethod(string code) + { + List classMethods = new List(); + var ms = Regex.Matches(code, methodRegex3); + foreach (Match m in ms) { + ClassMethod classMethod = new ClassMethod(); + + classMethod.ReturnType = m.Groups[2].Value; + classMethod.MethodName = m.Groups[3].Value; + classMethod.Paramenters = ClassMethodParamenter.AnalysisCode(m.Groups[4].Value, m.Groups[5].Value); + classMethod.Variables = ClassMethodVariable.AnalysisCode(m.Groups[5].Value, classMethod.Paramenters); + classMethods.Add(classMethod); + } + return classMethods; + } + public static List AnalysisCode(string code) + { + List classMethods = new List(); + var ms = Regex.Matches(code, methodRegex); + foreach (Match m in ms) { + ClassMethod classMethod = new ClassMethod(); + classMethod.ReturnType = m.Groups[2].Value; + classMethod.MethodName = m.Groups[3].Value; + classMethod.Paramenters = ClassMethodParamenter.AnalysisCode(m.Groups[4].Value, m.Groups[5].Value); + classMethod.Variables = ClassMethodVariable.AnalysisCode(m.Groups[5].Value, classMethod.Paramenters); + classMethod.IsForwardMethod = classMethod.MethodName == "forward"; + classMethods.Add(classMethod); + } + return classMethods; + } + + public string ReplaceCodes(string code, List fields = null) + { + code = Regex.Replace(code, methodRegex2.Replace("{name}", MethodName), new MatchEvaluator(m => { + var ParamenterCode = m.Groups[3].Value; + foreach (var paramenter in Paramenters) { + ParamenterCode = paramenter.ReplaceCodes(ParamenterCode); + } + var bodyCode = m.Groups[4].Value; + foreach (var variable in Variables) { + bodyCode = variable.ReplaceCodes(bodyCode); + } + if (fields != null) { + foreach (var field in fields) { + if (field.HasForwardMethod || IsForwardMethod) { + bodyCode = Regex.Replace(bodyCode, @$"\bthis\.{field.FieldName}\(", $"this.{field.FieldName}.forward("); + bodyCode = Regex.Replace(bodyCode, @$"\bthis\.{field.FieldName}(\[([a-zA-Z_][a-zA-Z_0-9]*|\^?[0-9]+)\])\(", $"this.{field.FieldName}$1.forward("); + } + } + } + if (NewReturnType == null) { + if (ReturnType.StartsWith("Tuple<")) { + NewReturnType = ReturnType.Replace("Tuple<", "("); + NewReturnType = NewReturnType.Substring(0, NewReturnType.Length - 1) + ")"; + } + } + return $"public {m.Groups[1].Value} {NewReturnType ?? ReturnType} {MethodName}({ParamenterCode}){{{bodyCode}}}"; + })); + return code; + } + public override string ToString() + { + if (NewReturnType != null) { + return $"method: {NewReturnType} {MethodName}"; + } + return $"method: {ReturnType} {MethodName}"; + } + } + + public class ClassMethodParamenter + { + public string ParamenterName { get; set; } + public string Type { get; set; } + public string NewType { get; set; } + public string DefaultValue { get; set; } + + public static List AnalysisCode(string code, string text) + { + var fieldsRegex = TorchSharpInfo.Instance.TensorFieldRegex; + var methodRegex = TorchSharpInfo.Instance.TensorMethodRegex; + + + List classMethodParamenters = new List(); + if (string.IsNullOrEmpty(code)) { return classMethodParamenters; } + + var strs = Regex.Matches(code, "(.*?) ([a-zA-Z_][a-zA-Z_0-9]*)( = ([^,]+))?(,|$)"); + + foreach (Match str in strs) { + ClassMethodParamenter classMethodParamenter = new ClassMethodParamenter(); + classMethodParamenters.Add(classMethodParamenter); + classMethodParamenter.Type = str.Groups[1].Value.Trim(); + classMethodParamenter.ParamenterName = str.Groups[2].Value.Trim(); + var name = classMethodParamenter.ParamenterName; + if (name == "inputs") { + + } + if (str.Groups[3].Success) { + classMethodParamenter.DefaultValue = str.Groups[4].Value.Trim(); + + if (classMethodParamenter.DefaultValue == "true" || classMethodParamenter.DefaultValue == "false") { + classMethodParamenter.NewType = "bool"; + } else if (classMethodParamenter.DefaultValue.StartsWith("\"")) { + classMethodParamenter.NewType = "string"; + } else if (Regex.IsMatch(classMethodParamenter.DefaultValue, @"\-?\d+\.\d+")) { + classMethodParamenter.NewType = "double"; + } else if (Regex.IsMatch(classMethodParamenter.DefaultValue, @"\-?\d+")) { + classMethodParamenter.NewType = "int"; + } else if (classMethodParamenter.DefaultValue == "null") { + if (Regex.IsMatch(text, @$"{name} = {name} \?\? [a-zA-Z_][a-zA-Z_0-9]* [\+\-\*\/] [a-zA-Z_][a-zA-Z_0-9]*;")) { + classMethodParamenter.NewType = "int?"; + } + } + if (classMethodParamenter.NewType != null) { continue; } + } + if (text.Contains($"if ({name})") || text.Contains($"if (!{name})")) { + classMethodParamenter.NewType = "bool"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=) (false|true)") || Regex.IsMatch(text, $@"(\(|&& |\|\| )!?{name} (&&|\|\|)") || Regex.IsMatch(text, $@"(&&|\|\|) !?{name}(\)| &&| \|\|)")) { + classMethodParamenter.NewType = "bool"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.(split|startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { + classMethodParamenter.NewType = "string"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { + classMethodParamenter.NewType = "doulbe"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { + classMethodParamenter.NewType = "int"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\[[^\]]*?TensorIndex\.")) { + classMethodParamenter.NewType = "Tensor"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{fieldsRegex}[ ,;)\[]")) { + classMethodParamenter.NewType = "Tensor"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{methodRegex}\(")) { + classMethodParamenter.NewType = "Tensor"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(channels|index|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + classMethodParamenter.NewType = "int"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^.*(_path|_name|_dir)$")) { + classMethodParamenter.NewType = "string"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { + // classMethodParamenter.NewType = "double"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@"[(, ]{name} [\+\-\*\/] ")) { + // classMethodParamenter.NewType = "double"; + } + } + return classMethodParamenters; + } + + public string ReplaceCodes(string code) + { + if (NewType == null || NewType == Type) { return code; } + return Regex.Replace(code, $@"\b{Type} {ParamenterName}\b", $"{NewType} {ParamenterName}"); + } + + public override string ToString() + { + if (NewType != null) { + return $"paramenter: {NewType} {ParamenterName}"; + } + return $"paramenter: {Type} {ParamenterName}"; + } + } + + public class ClassMethodVariable + { + public string Type { get; set; } + public string NewType { get; set; } + public string HiddenType { get; set; } + public string VariableName { get; set; } + + public static List AnalysisCode(string code, List paramenters) + { + List classMethodVariables = new List(); + var texts = code.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries); + + HashSet names = new HashSet(); + names.Add("_"); + foreach (var paramenter in paramenters) { + names.Add(paramenter.ParamenterName); + } + foreach (var text in texts) { + var m = Regex.Match(text, @"\t*([a-zA-Z_][a-zA-Z0-9_]*) ([a-zA-Z_][a-zA-Z0-9_]*) = "); + if (m.Success) { + if (names.Add(m.Groups[1].Value)) { + ClassMethodVariable classMethodVariable = new ClassMethodVariable(); + classMethodVariable.Type = m.Groups[1].Value; + classMethodVariable.VariableName = m.Groups[2].Value; + + classMethodVariables.Add(classMethodVariable); + } + continue; + } + m = Regex.Match(text, @"\t*([a-zA-Z_][a-zA-Z0-9_]*) = "); + if (m.Success) { + if (names.Add(m.Groups[1].Value)) { + ClassMethodVariable classMethodVariable = new ClassMethodVariable(); + classMethodVariable.VariableName = m.Groups[1].Value; + classMethodVariables.Add(classMethodVariable); + } + continue; + } + + m = Regex.Match(text, @"\t*\(([^)]+)\) = "); + if (m.Success) { + var str = m.Groups[1].Value; + var sp = str.Split(','); + foreach (var sp1 in sp) { + var s = sp1.Trim(); + if (names.Add(s)) { + ClassMethodVariable classMethodVariable = new ClassMethodVariable(); + classMethodVariable.VariableName = m.Groups[1].Value; + classMethodVariables.Add(classMethodVariable); + } + } + continue; + } + } + return classMethodVariables; + } + + public string ReplaceCodes(string code) + { + return code; + //if (Type != null && NewType != Type) { + // code = Regex.Replace(code, $@"\b{Type} {VariableName}", $"{NewType} {VariableName}"); + //} + //return code; + } + + public override string ToString() + { + if (NewType != null) { + return $"variable: {NewType} {VariableName}"; + } + return $"variable: {Type} {VariableName}"; + } + } + + + +} diff --git a/src/Extensions/TorchCs/TorchCs.csproj b/src/Extensions/TorchCs/TorchCs.csproj index 27701f2..8a27351 100644 --- a/src/Extensions/TorchCs/TorchCs.csproj +++ b/src/Extensions/TorchCs/TorchCs.csproj @@ -3,7 +3,7 @@ net6.0 enable - enable + disable diff --git a/src/Extensions/TorchCs/TorchSharpInfo.cs b/src/Extensions/TorchCs/TorchSharpInfo.cs new file mode 100644 index 0000000..3fcdf99 --- /dev/null +++ b/src/Extensions/TorchCs/TorchSharpInfo.cs @@ -0,0 +1,96 @@ +using System.Reflection; + +namespace TorchCs +{ + public class TorchSharpInfo + { + //public TorchSharpMethodList TorchSharpMethods; + public Type nnType; + public MethodInfo[] nnMethods; + public List nnModelNames; + + public Type torchType; + public MethodInfo[] torchMethods; + + public Type TensorType; + + public MethodInfo[] TensorMethods; + public string TensorFieldRegex; + public string TensorMethodRegex; + + public static TorchSharpInfo Instance=new TorchSharpInfo(); + + private TorchSharpInfo() + { + nnType = typeof(TorchSharp.torch.nn); + nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); + + nnModelNames = new List(); + foreach (var method in nnMethods) { + if (method.Name == "ModuleDict" || method.Name == "ModuleList") { continue; } + nnModelNames.Add(method.ReturnType.Name); + } + nnModelNames = nnModelNames.Distinct().ToList(); + + TensorType = typeof(TorchSharp.torch.Tensor); + var fields = TensorType.GetFields(); + var properties = TensorType.GetProperties(); + HashSet fs = new HashSet(); + foreach (var fieldInfo in fields) { fs.Add(fieldInfo.Name); } + foreach (var fieldInfo in properties) { fs.Add(fieldInfo.Name); } + fs.Remove("device"); + TensorFieldRegex = "(" + string.Join("|", fs) + ")"; + TensorMethods = TensorType.GetMethods(BindingFlags.Public | BindingFlags.Instance); + fs.Clear(); + foreach (var fieldInfo in TensorMethods) { fs.Add(fieldInfo.Name); } + TensorMethodRegex = "(" + string.Join("|", fs) + ")"; + + var torchType = typeof(TorchSharp.torch); + torchMethods = torchType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); + } + } + + //public class TorchSharpMethodList : List + //{ + // private static TorchSharpMethodList _TorchSharpMethods; + // private static TorchSharpMethodList GetTorchSharpMethods() + // { + // if (_TorchSharpMethods == null) { + + // } + // return _TorchSharpMethods; + // } + // public TorchSharpMethod GetMethod(string methodName, List paramenters) + // { + // return null; + // } + + //} + + //public class TorchSharpMethod + //{ + // public string MethodName { get; set; } + // public string TypeName { get; set; } + // public List Paramenters { get; set; } + // public string ReplaceCodes(string code) + // { + + // return code; + // } + //} + + //public class MethodParamenter + //{ + // public int Index { get; set; } + // public string Name { get; set; } + // public string TypeName { get; set; } + // public bool IsOptional { get; set; } + + // public string ReplaceCodes(string code) + // { + + // return code; + // } + //} + +} diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index b37d3e2..05aeaea 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -65,9 +65,21 @@ public static string ReplaceCodes(string text) // replace 'd_keys = d_keys or (d_model//n_heads)' to 'd_keys = d_keys ?? d_model / n_heads;' text = Regex.Replace(text, @"([a-zA-Z_0-9]+) = (\1 \|\| (.*?;))", "$1 = $1 ?? $3 //$2"); - text = replaceNamespace(text); text = replaceConstructor(text); + text = replaceListSlice(text); + + // Replace type by class area and static method area + var classInfos = ClassInfo.AnalysisCode(text); + foreach (var classInfo in classInfos) { + text = classInfo.AddNewField(text); // Add missing fields + text = classInfo.ReplaceCodes(text); + } + var sss = ClassMethod.AnalysisCodeForStaticMethod(text); + foreach (var item in sss) { + text = item.ReplaceCodes(text); + } + text = replaceFieldType(text); text = replaceMethodParameterName(text); text = replaceMethodParamenterType(text); @@ -78,7 +90,6 @@ public static string ReplaceCodes(string text) text = replaceForwardMethod(text); text = replaceCallForwardMethod(text); - text = replaceListSlice(text); text = replaceTensorList(text); text = replaceIsType(text); @@ -212,7 +223,22 @@ private static string replaceFieldType3(string text) if (ms.Count > 0) { foreach (Match m in ms) { var name = m.Groups[2].Value; - if (text.Contains($"this.{name} = {name};")) { + if (text.Contains($"if (this.{name})") || text.Contains($"if (!this.{name})") || text.Contains($"if (this.{name} == true)") || text.Contains($"if (this.{name} == false)")) { + text = text.Replace($"public object {name};", $"public bool {name};"); + text = text.Replace($"public void {name};", $"public bool {name};"); + } else if (text.Contains($"this.{name} = false") || text.Contains($"this.{name} = true")) { + text = text.Replace($"public object {name};", $"public bool {name};"); + text = text.Replace($"public void {name};", $"public bool {name};"); + } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { + text = text.Replace($"public object {name};", $"public string {name};"); + text = text.Replace($"public void {name};", $"public string {name};"); + } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { + text = text.Replace($"public object {name};", $"public doulbe {name};"); + text = text.Replace($"public void {name};", $"public doulbe {name};"); + } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { + text = text.Replace($"public object {name};", $"public int {name};"); + text = text.Replace($"public void {name};", $"public int {name};"); + } else if (text.Contains($"this.{name} = {name};")) { if (Regex.IsMatch(text, @$"int {name}\b")) { text = text.Replace($"public object {name};", $"public int {name};"); text = text.Replace($"public void {name};", $"public int {name};"); @@ -229,21 +255,6 @@ private static string replaceFieldType3(string text) text = text.Replace($"public object {name};", $"public bool {name};"); text = text.Replace($"public void {name};", $"public bool {name};"); } - } else if (text.Contains($"if (this.{name})") || text.Contains($"if (!this.{name})") || text.Contains($"if (this.{name} == true)") || text.Contains($"if (this.{name} == false)")) { - text = text.Replace($"public object {name};", $"public bool {name};"); - text = text.Replace($"public void {name};", $"public bool {name};"); - } else if (text.Contains($"this.{name} = false") || text.Contains($"this.{name} = true")) { - text = text.Replace($"public object {name};", $"public bool {name};"); - text = text.Replace($"public void {name};", $"public bool {name};"); - } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { - text = text.Replace($"public object {name};", $"public string {name};"); - text = text.Replace($"public void {name};", $"public string {name};"); - } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { - text = text.Replace($"public object {name};", $"public doulbe {name};"); - text = text.Replace($"public void {name};", $"public doulbe {name};"); - } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { - text = text.Replace($"public object {name};", $"public int {name};"); - text = text.Replace($"public void {name};", $"public int {name};"); } } @@ -457,9 +468,9 @@ private static string replaceForwardMethod(string text) text = text.Replace(" Tuple> forward(", " (Tensor, List) forward("); text = text.Replace(" object forward(", " Tensor forward("); text = text.Replace(" void forward(", " Tensor forward("); - text = text.Replace(" forward(object x", " forward(Tensor x"); - text = text.Replace(" forward(object t", " forward(Tensor t"); - text = text.Replace(" forward(object queries, object keys, object values", " forward(Tensor queries, Tensor keys, Tensor values"); + //text = text.Replace(" forward(object x", " forward(Tensor x"); + //text = text.Replace(" forward(object t", " forward(Tensor t"); + //text = text.Replace(" forward(object queries, object keys, object values", " forward(Tensor queries, Tensor keys, Tensor values"); return text; } /// From a423599438fbc0d14ed6ec8aa40342a97a2464b4 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 2 Mar 2023 10:56:07 +0800 Subject: [PATCH 05/27] Class creation of standardized import using DataEmbedding = layers.Embed.DataEmbedding; this.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout); => this.enc_embedding = new DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout); --- src/Extensions/TorchCs/TorchUtil.cs | 47 +++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 05aeaea..1b8d966 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -19,6 +19,7 @@ using System.Text.RegularExpressions; using System.Xml.Linq; using TorchSharp; +using static TorchSharp.torch; namespace TorchCs { @@ -33,9 +34,18 @@ public class TorchUtil public static void ReplaceFolder(string folder) { var files = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories); + HashSet classNames = new HashSet(); foreach (var file in files) { var text = File.ReadAllText(file); - File.WriteAllText(file, ReplaceCodes(text)); + getClassName(text, classNames); + } + classNames.Remove("torch"); + classNames.Remove("nn"); + classNames.Remove("F"); + + foreach (var file in files) { + var text = File.ReadAllText(file); + File.WriteAllText(file, ReplaceCodes(text, classNames)); } } /// @@ -52,7 +62,7 @@ public static void ReplaceFile(string file) /// /// /// - public static string ReplaceCodes(string text) + public static string ReplaceCodes(string text, HashSet classNames = null) { // replace 'self' to 'this' text = Regex.Replace(text, @"\bself\.", "this."); @@ -68,7 +78,7 @@ public static string ReplaceCodes(string text) text = replaceNamespace(text); text = replaceConstructor(text); text = replaceListSlice(text); - + text = replaceNewClass(text, classNames); // Replace type by class area and static method area var classInfos = ClassInfo.AnalysisCode(text); foreach (var classInfo in classInfos) { @@ -673,6 +683,37 @@ private static string replaceStringToNetstandard(string text) } + private static string replaceNewClass(string text, HashSet classNames) + { + if (classNames == null) { return text; } + const string classRegex = @"using ([a-zA-Z_][a-zA-Z0-9_]*) = ([a-zA-Z_][a-zA-Z0-9_.]*);"; + + List names = new List(); + var ms = Regex.Matches(text, classRegex); + foreach (Match m in ms) { + if (classNames.Contains(m.Groups[1].Value)) { + names.Add(m.Groups[1].Value); + } + } + if (names.Count == 0) { return text; } + + var namereg = string.Join("|", names); + text = Regex.Replace(text, $@"\b({namereg})\(", "new $1("); + text = Regex.Replace(text, @"\bnew new ", "new "); + return text; + } + + private static void getClassName(string text, HashSet classNames) + { + const string classRegex = @"public class ([a-zA-Z_][a-zA-Z0-9_]*)"; + var ms = Regex.Matches(text, classRegex); + foreach (Match m in ms) { + classNames.Add(m.Groups[1].Value); + } + } + + + } } From c42e803e86a3cf3fe6a497a9929a0e668babe751 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 2 Mar 2023 11:05:29 +0800 Subject: [PATCH 06/27] replace forward method's ReturnType object => Tensor --- src/Extensions/TorchCs/ClassInfo.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index a67d3f2..474a764 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -329,6 +329,9 @@ public string ReplaceCodes(string code, List fields = null) if (ReturnType.StartsWith("Tuple<")) { NewReturnType = ReturnType.Replace("Tuple<", "("); NewReturnType = NewReturnType.Substring(0, NewReturnType.Length - 1) + ")"; + if (IsForwardMethod) { + NewReturnType = NewReturnType.Replace("object", "Tensor"); + } } } return $"public {m.Groups[1].Value} {NewReturnType ?? ReturnType} {MethodName}({ParamenterCode}){{{bodyCode}}}"; From 3504e68b86ab7122ed618831ba0e35a234f76f95 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 2 Mar 2023 13:38:58 +0800 Subject: [PATCH 07/27] Match method names beginning with @ . Set common name type . --- src/Extensions/TorchCs/ClassInfo.cs | 37 ++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index 474a764..4c1bebe 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -1,10 +1,7 @@ using System.Text.RegularExpressions; -using static System.Net.Mime.MediaTypeNames; namespace TorchCs { - - public class ClassFile { public string FileName { get; set; } @@ -227,9 +224,11 @@ public static List AnalysisCode(string code) field1.NewType = "int"; } else if (Regex.IsMatch(code, $@"this\.{name}\[[^\]]*?TensorIndex\.")) { field1.NewType = "Tensor"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(channels|index|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(dropout)$")) { + field1.NewType = "double"; + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { field1.NewType = "int"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^.*(_path|_name|_dir)$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(name|.*(_path|_name|_dir))$")) { field1.NewType = "string"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { // classMethodParamenter.NewType = "double"; @@ -262,9 +261,9 @@ public override string ToString() public class ClassMethod { - private const string methodRegex = @"public (virtual) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex = @"public (virtual) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; private const string methodRegex2 = @"public (virtual|static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) {name}\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; - private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; // public static int get_q_k(int input_size, int window_size, object stride, object device) public string MethodName { get; set; } @@ -331,6 +330,24 @@ public string ReplaceCodes(string code, List fields = null) NewReturnType = NewReturnType.Substring(0, NewReturnType.Length - 1) + ")"; if (IsForwardMethod) { NewReturnType = NewReturnType.Replace("object", "Tensor"); + NewReturnType = NewReturnType.Replace("void", "Tensor"); + } + } else if (ReturnType == "void") { + var ms = Regex.Matches(bodyCode, "return ([^;]*);"); + var max = 0; + foreach (Match item in ms) { + var num = item.Groups[1].Value.Split(','); + max = Math.Max(max, num.Length); + } + if (max == 1) { + NewReturnType = "object"; + } else if (max > 1) { + NewReturnType = "("; + for (int i = 0; i < max; i++) { + if (i > 0) { NewReturnType += ","; } + NewReturnType += "object"; + } + NewReturnType += ")"; } } } @@ -408,9 +425,11 @@ public static List AnalysisCode(string code, string text) classMethodParamenter.NewType = "Tensor"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{methodRegex}\(")) { classMethodParamenter.NewType = "Tensor"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(channels|index|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(dropout)$")) { + classMethodParamenter.NewType = "double"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { classMethodParamenter.NewType = "int"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^.*(_path|_name|_dir)$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(name|.*(_path|_name|_dir))$")) { classMethodParamenter.NewType = "string"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { // classMethodParamenter.NewType = "double"; From 2033357870da6632102c76022c927fc18af860f2 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 2 Mar 2023 15:49:05 +0800 Subject: [PATCH 08/27] Determine the field type of the class and the parameter type of the method according to the parameters of the 'nn.XXX' and 'torch.XXX' methods. --- src/Extensions/TorchCs/ClassInfo.cs | 34 +++- src/Extensions/TorchCs/TorchSharpInfo.cs | 192 +++++++++++++++++------ src/Extensions/TorchCs/TorchUtil.cs | 89 +---------- 3 files changed, 175 insertions(+), 140 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index 4c1bebe..6578eaa 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -1,4 +1,5 @@ using System.Text.RegularExpressions; +using static System.Net.Mime.MediaTypeNames; namespace TorchCs { @@ -224,16 +225,24 @@ public static List AnalysisCode(string code) field1.NewType = "int"; } else if (Regex.IsMatch(code, $@"this\.{name}\[[^\]]*?TensorIndex\.")) { field1.NewType = "Tensor"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(dropout)$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(dropout|.*_dropout)$")) { field1.NewType = "double"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|epoch|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { field1.NewType = "int"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(name|.*(_path|_name|_dir))$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, "^(name|path|dir|.*(_path|_name|_dir))$")) { field1.NewType = "string"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { // classMethodParamenter.NewType = "double"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@"[(, ]{name} [\+\-\*\/] ")) { // classMethodParamenter.NewType = "double"; + } else { + var type = TorchSharpInfo.Instance.FindTypeBy_nn(code, "this." + field1.FieldName); + if (type == null) { + type = TorchSharpInfo.Instance.FindTypeBy_torch(code, "this." + field1.FieldName); + } + if (type != null) { + field1.NewType = type; + } } } return classFields; @@ -388,9 +397,9 @@ public static List AnalysisCode(string code, string text) classMethodParamenter.Type = str.Groups[1].Value.Trim(); classMethodParamenter.ParamenterName = str.Groups[2].Value.Trim(); var name = classMethodParamenter.ParamenterName; - if (name == "inputs") { + //if (name == "inputs") { - } + //} if (str.Groups[3].Success) { classMethodParamenter.DefaultValue = str.Groups[4].Value.Trim(); @@ -425,17 +434,26 @@ public static List AnalysisCode(string code, string text) classMethodParamenter.NewType = "Tensor"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{methodRegex}\(")) { classMethodParamenter.NewType = "Tensor"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(dropout)$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(dropout|.*_dropout)$")) { classMethodParamenter.NewType = "double"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|epoch|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { classMethodParamenter.NewType = "int"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(name|.*(_path|_name|_dir))$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(name|path|dir|.*(_path|_name|_dir))$")) { classMethodParamenter.NewType = "string"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { // classMethodParamenter.NewType = "double"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@"[(, ]{name} [\+\-\*\/] ")) { // classMethodParamenter.NewType = "double"; + } else { + var type = TorchSharpInfo.Instance.FindTypeBy_nn(text, classMethodParamenter.ParamenterName); + if (type == null) { + type = TorchSharpInfo.Instance.FindTypeBy_torch(text, classMethodParamenter.ParamenterName); + } + if (type != null) { + classMethodParamenter.NewType = type; + } } + } return classMethodParamenters; } diff --git a/src/Extensions/TorchCs/TorchSharpInfo.cs b/src/Extensions/TorchCs/TorchSharpInfo.cs index 3fcdf99..8079ce5 100644 --- a/src/Extensions/TorchCs/TorchSharpInfo.cs +++ b/src/Extensions/TorchCs/TorchSharpInfo.cs @@ -1,10 +1,18 @@ using System.Reflection; +using System.Text.RegularExpressions; +using TorchSharp; namespace TorchCs { public class TorchSharpInfo { - //public TorchSharpMethodList TorchSharpMethods; + private Dictionary dict=new Dictionary() { + {"Int64","long" }, + {"Int32","int" }, + {"String","string" }, + {"Single","float" }, + {"Double","double" }, + }; public Type nnType; public MethodInfo[] nnMethods; public List nnModelNames; @@ -18,7 +26,10 @@ public class TorchSharpInfo public string TensorFieldRegex; public string TensorMethodRegex; - public static TorchSharpInfo Instance=new TorchSharpInfo(); + private TorchSharpMethodList nn_methods; + private TorchSharpMethodList torch_methods; + + public static TorchSharpInfo Instance = new TorchSharpInfo(); private TorchSharpInfo() { @@ -47,50 +58,143 @@ private TorchSharpInfo() var torchType = typeof(TorchSharp.torch); torchMethods = torchType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); + + nn_methods = new TorchSharpMethodList(nnMethods); + torch_methods = new TorchSharpMethodList(torchMethods); + } + + public string FindTypeBy_nn(string code, string text) + { + var names = nn_methods.Select(q => q.MethodName).Distinct().ToList(); + string reg = $@"\bnn\.({string.Join("|", names)})\((((?
\()|(?<-BR>\))|[^()])+)\)"; + var ms = Regex.Matches(code, reg); + foreach (Match m in ms) { + if (m.Value.Contains(text) == false) { continue; } + var p = m.Groups[2].Value; + var type = FindTypeBy_nn(p, text); + if (type != null) { return type; } + var ms2 = Regex.Matches(p.Trim(), @"(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|(?\[)|(?<-BR3>\])|(?\<)|(?<-BR3>\>)|[^\<\>\(\)\{\}\[\]])+?)(,|$)"); + var ps = new List(); + foreach (Match m2 in ms2) { + ps.Add(m2.Groups[1].Value.Trim()); + } + if (ps.Contains(text) == false) { continue; } + + var methodName = m.Groups[1].Value; + var methods = nn_methods.Where(q => q.MethodName == methodName).ToList(); + foreach (var method in methods) { + if (method.Check(ps)) { + var index = ps.IndexOf(text); + var pi = method.Paramenters[index]; + if (pi.IsGenericType == false) { + if (dict.TryGetValue(pi.TypeName,out string t)) { + return t; + } + return pi.TypeName; + } + } + } + } + return null; + } + public string FindTypeBy_torch(string code, string text) + { + var names = torch_methods.Select(q => q.MethodName).Distinct().ToList(); + string reg = $@"\btorch\.({string.Join("|", names)})\((((?
\()|(?<-BR>\))|[^()])+)\)"; + var ms = Regex.Matches(code, reg); + foreach (Match m in ms) { + if (m.Value.Contains(text) == false) { continue; } + var p = m.Groups[2].Value; + var type = FindTypeBy_nn(p, text); + if (type != null) { return type; } + var ms2 = Regex.Matches(p.Trim(), @"(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|(?\[)|(?<-BR3>\])|(?\<)|(?<-BR3>\>)|[^\<\>\(\)\{\}\[\]])+?)(,|$)"); + var ps = new List(); + foreach (Match m2 in ms2) { + ps.Add(m2.Groups[1].Value.Trim()); + } + if (ps.Contains(text) == false) { continue; } + + var methodName = m.Groups[1].Value; + var methods = torch_methods.Where(q => q.MethodName == methodName).ToList(); + foreach (var method in methods) { + if (method.Check(ps)) { + var index = ps.IndexOf(text); + var pi = method.Paramenters[index]; + if (pi.IsGenericType == false) { + if (dict.TryGetValue(pi.TypeName, out string t)) { + return t; + } + return pi.TypeName; + } + } + } + } + return null; + } + + + + } + + public class TorchSharpMethodList : List + { + public TorchSharpMethodList(MethodInfo[] methods) + { + foreach (var method in methods) { + Add(new TorchSharpMethod(method)); + } + } + } + + public class TorchSharpMethod + { + public string MethodName { get; set; } + public string ReturnType { get; set; } + public List Paramenters { get; set; } + + public TorchSharpMethod() { } + public TorchSharpMethod(MethodInfo methodInfo) + { + MethodName = methodInfo.Name; + ReturnType = methodInfo.ReturnType.Name; + Paramenters = new List(); + var ps = methodInfo.GetParameters(); + for (int i = 0; i < ps.Length; i++) { + Paramenters.Add(new MethodParamenter(i, ps[i])); + } + } + public bool Check(List ps) + { + if (Paramenters.Count < ps.Count) { return false; } + foreach (var p in ps) { + if (p.Contains(":")) { + var name = p.Substring(0, p.IndexOf(':')); + if (Paramenters.Any(q => q.Name == name) == false) { + return false; + } + } + } + return true; } } - //public class TorchSharpMethodList : List - //{ - // private static TorchSharpMethodList _TorchSharpMethods; - // private static TorchSharpMethodList GetTorchSharpMethods() - // { - // if (_TorchSharpMethods == null) { - - // } - // return _TorchSharpMethods; - // } - // public TorchSharpMethod GetMethod(string methodName, List paramenters) - // { - // return null; - // } - - //} - - //public class TorchSharpMethod - //{ - // public string MethodName { get; set; } - // public string TypeName { get; set; } - // public List Paramenters { get; set; } - // public string ReplaceCodes(string code) - // { - - // return code; - // } - //} - - //public class MethodParamenter - //{ - // public int Index { get; set; } - // public string Name { get; set; } - // public string TypeName { get; set; } - // public bool IsOptional { get; set; } - - // public string ReplaceCodes(string code) - // { - - // return code; - // } - //} + public class MethodParamenter + { + public int Index { get; set; } + public string Name { get; set; } + public string TypeName { get; set; } + public bool IsGenericType { get; set; } + public bool IsOptional { get; set; } + + public MethodParamenter() { } + public MethodParamenter(int index, ParameterInfo parameter) + { + Index = index; + Name = parameter.Name; + TypeName = parameter.ParameterType.Name; + IsOptional = parameter.IsOptional; + IsGenericType = parameter.ParameterType.IsGenericType; + } + } } diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 1b8d966..4868fcb 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -79,6 +79,7 @@ public static string ReplaceCodes(string text, HashSet classNames = null text = replaceConstructor(text); text = replaceListSlice(text); text = replaceNewClass(text, classNames); + text = replaceMethodParameterName(text); // Replace type by class area and static method area var classInfos = ClassInfo.AnalysisCode(text); foreach (var classInfo in classInfos) { @@ -91,16 +92,11 @@ public static string ReplaceCodes(string text, HashSet classNames = null } text = replaceFieldType(text); - text = replaceMethodParameterName(text); text = replaceMethodParamenterType(text); text = replaceMathMethod(text); text = replaceStringToEnum(text); text = replaceMethodAlias(text); - text = replaceForwardMethod(text); - text = replaceCallForwardMethod(text); - - text = replaceTensorList(text); text = replaceIsType(text); @@ -465,89 +461,6 @@ private static string replaceMathMethod(string text) text = Regex.Replace(text, @"\bmath\.inf\b", "double.PositiveInfinity"); return text; } - /// - /// Replace forward method's return type and forward method's parameter type - /// - /// - /// - private static string replaceForwardMethod(string text) - { - text = text.Replace(" Tuple", " (Tensor, Tensor)"); - text = text.Replace(" Tuple forward(", " (Tensor, Tensor) forward("); - text = text.Replace(" object[] forward(", " (Tensor, Tensor) forward("); - text = text.Replace(" Tuple> forward(", " (Tensor, List) forward("); - text = text.Replace(" object forward(", " Tensor forward("); - text = text.Replace(" void forward(", " Tensor forward("); - //text = text.Replace(" forward(object x", " forward(Tensor x"); - //text = text.Replace(" forward(object t", " forward(Tensor t"); - //text = text.Replace(" forward(object queries, object keys, object values", " forward(Tensor queries, Tensor keys, Tensor values"); - return text; - } - /// - /// Replace common forward method calls - /// - /// - /// - private static string replaceCallForwardMethod(string text) - { - text = Regex.Replace(text, @"\bthis\.inner_attention\(", "this.inner_attention.forward("); - text = Regex.Replace(text, @"\bthis\.dropout\(", "this.dropout.forward("); - text = Regex.Replace(text, @"\bthis\.attention\(", "this.attention.forward("); - text = Regex.Replace(text, @"\bthis\.self_attention\(", "this.self_attention.forward("); - text = Regex.Replace(text, @"\bthis\.cross_attention\(", "this.cross_attention.forward("); - text = Regex.Replace(text, @"\bthis\.projection\(", "this.projection.forward("); - text = Regex.Replace(text, @"\bthis\.activation\(", "this.activation.forward("); - text = Regex.Replace(text, @"\bthis\.norm\(", "this.norm.forward("); - text = Regex.Replace(text, @"\bthis\.conv\(", "this.conv.forward("); - text = Regex.Replace(text, @"\bthis\.decomp\(", "this.decomp.forward("); - text = Regex.Replace(text, @"\bthis\.decomp1\(", "this.decomp1.forward("); - text = Regex.Replace(text, @"\bthis\.decomp2\(", "this.decomp2.forward("); - text = Regex.Replace(text, @"\bthis\.decomp3\(", "this.decomp3.forward("); - text = Regex.Replace(text, @"\bthis\.decomp4\(", "this.decomp4.forward("); - text = Regex.Replace(text, @"\bthis\.decomp5\(", "this.decomp5.forward("); - text = Regex.Replace(text, @"\bthis\.conv1\(", "this.conv1.forward("); - text = Regex.Replace(text, @"\bthis\.conv2\(", "this.conv2.forward("); - text = Regex.Replace(text, @"\bthis\.conv3\(", "this.conv3.forward("); - text = Regex.Replace(text, @"\bthis\.conv4\(", "this.conv4.forward("); - text = Regex.Replace(text, @"\bthis\.conv5\(", "this.conv5.forward("); - text = Regex.Replace(text, @"\bthis\.norm1\(", "this.norm1.forward("); - text = Regex.Replace(text, @"\bthis\.norm2\(", "this.norm2.forward("); - text = Regex.Replace(text, @"\bthis\.norm3\(", "this.norm3.forward("); - text = Regex.Replace(text, @"\bthis\.norm4\(", "this.norm4.forward("); - text = Regex.Replace(text, @"\bthis\.norm5\(", "this.norm5.forward("); - - text = Regex.Replace(text, @"\bthis\.downConv\(", "this.downConv.forward("); - text = Regex.Replace(text, @"\bthis\.maxPool\(", "this.maxPool.forward("); - text = Regex.Replace(text, @"\bthis\.avg\(", "this.avg.forward("); - text = Regex.Replace(text, @"\bthis\.layernorm\(", "this.layernorm.forward("); - text = Regex.Replace(text, @"\bthis\.tokenConv\(", "this.tokenConv.forward("); - - text = Regex.Replace(text, @"\bthis\.embedding\(", "this.embedding.forward("); - text = Regex.Replace(text, @"\bthis\.emb\(", "this.emb.forward("); - text = Regex.Replace(text, @"\bthis\.embed\(", "this.embed.forward("); - text = Regex.Replace(text, @"\bthis\.position_embedding\(", "this.position_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.temporal_embedding\(", "this.temporal_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.value_embedding\(", "this.value_embedding.forward("); - - text = Regex.Replace(text, @"\bthis\.month_embed\(", "this.month_embed.forward("); - text = Regex.Replace(text, @"\bthis\.day_embed\(", "this.day_embed.forward("); - text = Regex.Replace(text, @"\bthis\.hour_embed\(", "this.hour_embed.forward("); - text = Regex.Replace(text, @"\bthis\.minute_embed\(", "this.minute_embed.forward("); - text = Regex.Replace(text, @"\bthis\.weekday_embed\(", "this.weekday_embed.forward("); - - text = Regex.Replace(text, @"\bthis\.enc_embedding\(", "this.enc_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.encoder\(", "this.encoder.forward("); - text = Regex.Replace(text, @"\bthis\.dec_embedding\(", "this.dec_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.decoder\(", "this.decoder.forward("); - - text = Regex.Replace(text, @"\bthis\.query_projection\(", "this.query_projection.forward("); - text = Regex.Replace(text, @"\bthis\.key_projection\(", "this.key_projection.forward("); - text = Regex.Replace(text, @"\bthis\.value_projection\(", "this.value_projection.forward("); - text = Regex.Replace(text, @"\bthis\.out_projection\(", "this.out_projection.forward("); - - text = Regex.Replace(text, @"\bthis\.attn\(", "this.attn.forward("); - return text; - } /// /// Replace common Tensor list From 3e914fba181adf3f8c22eec5dff040aa2f6af889 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 2 Mar 2023 16:57:29 +0800 Subject: [PATCH 09/27] fix split paramenter string bug. --- src/Extensions/TorchCs/ClassInfo.cs | 15 ++++++++-- src/Extensions/TorchCs/TorchSharpInfo.cs | 12 ++------ src/Extensions/TorchCs/TorchUtil.cs | 37 ++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index 6578eaa..a059e7e 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -341,12 +341,17 @@ public string ReplaceCodes(string code, List fields = null) NewReturnType = NewReturnType.Replace("object", "Tensor"); NewReturnType = NewReturnType.Replace("void", "Tensor"); } - } else if (ReturnType == "void") { + } else if (ReturnType == "void" || ReturnType == "object") { var ms = Regex.Matches(bodyCode, "return ([^;]*);"); var max = 0; foreach (Match item in ms) { - var num = item.Groups[1].Value.Split(','); - max = Math.Max(max, num.Length); + if (item.Groups[1].Value.StartsWith('(')) { + var t= item.Groups[1].Value.Substring(1, item.Groups[1].Value.Length-2); + var ms2 = TorchUtil.splitParamenters(t); + max = Math.Max(max, ms2.Count); + } else { + max = Math.Max(max, 1); + } } if (max == 1) { NewReturnType = "object"; @@ -358,6 +363,10 @@ public string ReplaceCodes(string code, List fields = null) } NewReturnType += ")"; } + if (IsForwardMethod) { + NewReturnType = (NewReturnType?? ReturnType).Replace("object", "Tensor"); + NewReturnType = NewReturnType.Replace("void", "Tensor"); + } } } return $"public {m.Groups[1].Value} {NewReturnType ?? ReturnType} {MethodName}({ParamenterCode}){{{bodyCode}}}"; diff --git a/src/Extensions/TorchCs/TorchSharpInfo.cs b/src/Extensions/TorchCs/TorchSharpInfo.cs index 8079ce5..6a50b9c 100644 --- a/src/Extensions/TorchCs/TorchSharpInfo.cs +++ b/src/Extensions/TorchCs/TorchSharpInfo.cs @@ -73,11 +73,7 @@ public string FindTypeBy_nn(string code, string text) var p = m.Groups[2].Value; var type = FindTypeBy_nn(p, text); if (type != null) { return type; } - var ms2 = Regex.Matches(p.Trim(), @"(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|(?\[)|(?<-BR3>\])|(?\<)|(?<-BR3>\>)|[^\<\>\(\)\{\}\[\]])+?)(,|$)"); - var ps = new List(); - foreach (Match m2 in ms2) { - ps.Add(m2.Groups[1].Value.Trim()); - } + var ps = TorchUtil.splitParamenters(p.Trim()); if (ps.Contains(text) == false) { continue; } var methodName = m.Groups[1].Value; @@ -107,11 +103,7 @@ public string FindTypeBy_torch(string code, string text) var p = m.Groups[2].Value; var type = FindTypeBy_nn(p, text); if (type != null) { return type; } - var ms2 = Regex.Matches(p.Trim(), @"(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|(?\[)|(?<-BR3>\])|(?\<)|(?<-BR3>\>)|[^\<\>\(\)\{\}\[\]])+?)(,|$)"); - var ps = new List(); - foreach (Match m2 in ms2) { - ps.Add(m2.Groups[1].Value.Trim()); - } + var ps = TorchUtil.splitParamenters(p.Trim()); if (ps.Contains(text) == false) { continue; } var methodName = m.Groups[1].Value; diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 4868fcb..453cf67 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -42,7 +42,6 @@ public static void ReplaceFolder(string folder) classNames.Remove("torch"); classNames.Remove("nn"); classNames.Remove("F"); - foreach (var file in files) { var text = File.ReadAllText(file); File.WriteAllText(file, ReplaceCodes(text, classNames)); @@ -626,7 +625,41 @@ private static void getClassName(string text, HashSet classNames) } - + internal static List splitParamenters(string paramenters) + { + bool inText = false; + int bracketLayer = 0; + + List result = new List(); + var index = 0; + string temp = ""; + while (index < paramenters.Length) { + var c = paramenters[index]; + if (inText) { + temp += c; + if (c == '\\') { + index++; + temp += paramenters[index]; + } else if (c == '"') { + inText = false; + } + } else if (c == '(' || c == '{' || c == '[' || c == '<') { + bracketLayer++; + temp += c; + } else if (c == ')' || c == '}' || c == ']' || c == '>') { + bracketLayer--; + temp += c; + } else if (c == ',' && bracketLayer == 0) { + result.Add(temp); + temp = ""; + } else { + temp += c; + } + index++; + } + result.Add(temp); + return result; + } } } From 890bf2e74685cb1dc73566cf28f3785028ffe368 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Fri, 3 Mar 2023 10:15:24 +0800 Subject: [PATCH 10/27] Modify the corresponding method parameter type according to the method of the relevant class in the folder --- src/Extensions/TorchCs/ClassInfo.cs | 195 ++++++++++++++++++++++++---- src/Extensions/TorchCs/TorchUtil.cs | 29 ++++- 2 files changed, 193 insertions(+), 31 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index a059e7e..646dea8 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -1,3 +1,4 @@ +using Pytocs.Core.Types; using System.Text.RegularExpressions; using static System.Net.Mime.MediaTypeNames; @@ -8,8 +9,8 @@ public class ClassFile public string FileName { get; set; } public string Code { get; set; } public List ClassInfos { get; set; } - - public List StaticMethods { get; set; } + public bool HasChange { get; set; } + public bool LastChange { get; set; } public static List LoadFiles(string folder) { @@ -21,10 +22,94 @@ public static List LoadFiles(string folder) classFile.FileName = file; classFile.Code = text; classFile.ClassInfos = ClassInfo.AnalysisCode(text); - classFile.StaticMethods = ClassMethod.AnalysisCodeForStaticMethod(text); + classFile.HasChange = true; + classFile.LastChange = true; + foreach (var item in classFile.ClassInfos) { + item.File = classFile; + } + files.Add(classFile); } return files; } + public Dictionary MatchClassInfo(string code, List classInfos) + { + Dictionary result = new Dictionary(); + var match = Regex.Match(code, @"namespace ([a-zA-Z_][a-zA-Z0-9._]*) "); + if (match.Success) { + var ns = match.Groups[1].Value.Split('.'); + + var ms = Regex.Matches(code, @"using ([a-zA-Z_][a-zA-Z0-9_]*) = ([a-zA-Z_][a-zA-Z0-9_.]*);"); + foreach (Match m in ms) { + var key = m.Groups[1].Value; + var name = m.Groups[2].Value; + var classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + if (classInfo.File.LastChange) { + result[key] = classInfo; + } + continue; + } + var sp = name.Split("."); + for (int i = 1; i < ns.Length; i++) { + var names = new string[sp.Length + i]; + for (int j = 0; j < i; j++) { + names[j] = ns[j]; + } + for (int j = 0; j < sp.Length; j++) { + names[j + i] = sp[j]; + } + name = string.Join(".", names); + classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + if (classInfo.File.LastChange) { + result[key] = classInfo; + } + break; + } + } + } + } + return result; + } + + public Dictionary MatchClassInfo(string code, List files) + { + Dictionary result = new Dictionary(); + var match = Regex.Match(code, @"namespace ([a-zA-Z_][a-zA-Z0-9._]*) "); + if (match.Success) { + var ns = match.Groups[1].Value.Split('.'); + + var classInfos = new List(); + foreach (var file in files) { classInfos.AddRange(file.ClassInfos); } + + var ms = Regex.Matches(code, @"using ([a-zA-Z_][a-zA-Z0-9_]*) = ([a-zA-Z_][a-zA-Z0-9_.]*);"); + foreach (Match m in ms) { + var key = m.Groups[1].Value; + var name = m.Groups[2].Value; + var classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + result[key] = classInfo; + continue; + } + List names = new List(); + names.AddRange(ns); + names.Add(""); + names.Add(""); + var sp = name.Split("."); + for (int i = 0; i < sp.Length; i++) { + names[names.Count - sp.Length + i] = sp[i]; + } + name = string.Join(".", names); + classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + result[key] = classInfo; + } + } + } + return result; + } + + } @@ -33,7 +118,7 @@ public class ClassInfo private const string classRegex = @"public class ([a-zA-Z_][a-zA-Z0-9_]*)([\s\S]*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; private const string classRegex2 = @"public class {name}([\s\S]*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; - public ClassFile File { get; set; } + internal ClassFile File { get; set; } public string FullClassName { get; set; } public string ClassName { get; set; } public bool HasForwardMethod { get; set; } @@ -60,6 +145,9 @@ public static List AnalysisCode(string code) classInfo.Methods = ClassMethod.AnalysisCode(bodyCode); classInfo.HasForwardMethod = classInfo.Methods.Any(q => q.MethodName == "forward"); + foreach (var item in classInfo.Methods) { + item.ClassInfo = classInfo; + } classInfos.Add(classInfo); } var fclass = classInfos.Where(q => q.HasForwardMethod).Select(q => q.ClassName).ToList(); @@ -87,15 +175,6 @@ public string AddNewField(string code) return code; } - public string ReplaceNewConstructor(string code, List classInfos) - { - foreach (var classInfo in classInfos) { - code = Regex.Replace(code, $@"\b{classInfo.ClassName}\(", $"new {classInfo.ClassName}("); - } - code = Regex.Replace(code, @"\bnew new ", "new "); - return code; - } - public string ReplaceCodes(string code) { code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { @@ -112,6 +191,38 @@ public string ReplaceCodes(string code) })); return code; } + public string ReplaceMethodParamenterType(string code, Dictionary classInfos) + { + code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { + var bodyCode = m.Groups[2].Value; + var baseClass = m.Groups[1].Value; + + Dictionary temp = new Dictionary(); + foreach (var field in Fields) { + if (classInfos.ContainsKey(field.NewType ?? field.Type)) { + temp[field.FieldName] = classInfos[field.NewType ?? field.Type]; + } + } + foreach (var method in Methods) { + bodyCode = method.ReplaceMethodParamenterType(bodyCode, temp); + } + return $"public class {ClassName}{baseClass}{{{bodyCode}}}"; + })); + return code; + } + + public string GetMethodParamenterType(string methodName, int paramenterIndex) + { + var method = Methods.FirstOrDefault(q => q.MethodName == methodName); + if (method != null) { + if (paramenterIndex < method.Paramenters.Count) { + var p = method.Paramenters[paramenterIndex]; + return p.NewType ?? p.Type; + } + } + return null; + } + public override string ToString() { return $"class: {ClassName}"; @@ -274,7 +385,7 @@ public class ClassMethod private const string methodRegex2 = @"public (virtual|static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) {name}\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; // public static int get_q_k(int input_size, int window_size, object stride, object device) - + internal ClassInfo ClassInfo { get; set; } public string MethodName { get; set; } public string ReturnType { get; set; } public string NewReturnType { get; set; } @@ -346,7 +457,7 @@ public string ReplaceCodes(string code, List fields = null) var max = 0; foreach (Match item in ms) { if (item.Groups[1].Value.StartsWith('(')) { - var t= item.Groups[1].Value.Substring(1, item.Groups[1].Value.Length-2); + var t = item.Groups[1].Value.Substring(1, item.Groups[1].Value.Length - 2); var ms2 = TorchUtil.splitParamenters(t); max = Math.Max(max, ms2.Count); } else { @@ -364,7 +475,7 @@ public string ReplaceCodes(string code, List fields = null) NewReturnType += ")"; } if (IsForwardMethod) { - NewReturnType = (NewReturnType?? ReturnType).Replace("object", "Tensor"); + NewReturnType = (NewReturnType ?? ReturnType).Replace("object", "Tensor"); NewReturnType = NewReturnType.Replace("void", "Tensor"); } } @@ -373,12 +484,46 @@ public string ReplaceCodes(string code, List fields = null) })); return code; } + + public string ReplaceMethodParamenterType(string code, Dictionary classInfos) + { + var paramenters = Paramenters.Where(q => q.NewType == null && q.Type == "object").ToList(); + if (paramenters.Count == 0) { return code; } + + code = Regex.Replace(code, methodRegex2.Replace("{name}", MethodName), new MatchEvaluator(m1 => { + var ParamenterCode = m1.Groups[3].Value; + var bodyCode = m1.Groups[4].Value; + + var reg = @"\bthis\.([a-zA-Z_][a-zA-Z_0-9]+)\.([^\(\.]+)\((((?
\()|(?<-BR>\))|[^()])+)\)"; + var ms = Regex.Matches(bodyCode, reg); + foreach (Match m in ms) { + var fieldName = m.Groups[1].Value; + if (classInfos.ContainsKey(fieldName) == false) continue; + var name = m.Groups[2].Value; + var ps = m.Groups[3].Value; + var ps2 = TorchUtil.splitParamenters(ps); + for (int i = paramenters.Count - 1; i >= 0; i--) { + var paramenter = paramenters[i]; + var index = ps2.IndexOf(paramenter.ParamenterName); + if (index >= 0) { + var type = classInfos[fieldName].GetMethodParamenterType(name, index); + if (type != null && type != "object") { + paramenter.NewType = type; + ParamenterCode = paramenter.ReplaceCodes(ParamenterCode); + this.ClassInfo.File.HasChange = true; + paramenters.RemoveAt(i); + } + } + } + } + return $"public {m1.Groups[1].Value} {NewReturnType ?? ReturnType} {MethodName}({ParamenterCode}){{{bodyCode}}}"; + })); + return code; + } + public override string ToString() { - if (NewReturnType != null) { - return $"method: {NewReturnType} {MethodName}"; - } - return $"method: {ReturnType} {MethodName}"; + return $"method: {NewReturnType ?? ReturnType} {MethodName}"; } } @@ -475,10 +620,7 @@ public string ReplaceCodes(string code) public override string ToString() { - if (NewType != null) { - return $"paramenter: {NewType} {ParamenterName}"; - } - return $"paramenter: {Type} {ParamenterName}"; + return $"paramenter: {NewType ?? Type} {ParamenterName}"; } } @@ -550,10 +692,7 @@ public string ReplaceCodes(string code) public override string ToString() { - if (NewType != null) { - return $"variable: {NewType} {VariableName}"; - } - return $"variable: {Type} {VariableName}"; + return $"variable: {NewType??Type} {VariableName}"; } } diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 453cf67..f4fb110 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -46,6 +46,28 @@ public static void ReplaceFolder(string folder) var text = File.ReadAllText(file); File.WriteAllText(file, ReplaceCodes(text, classNames)); } + + var fileInfos = ClassFile.LoadFiles(folder); + var classInfos = new List(); + foreach (var file in fileInfos) { classInfos.AddRange(file.ClassInfos); } + bool IsChange; + do { + IsChange = false; + foreach (var fileInfo in fileInfos) { + fileInfo.LastChange = fileInfo.HasChange; + fileInfo.HasChange = false; + } + foreach (var fileInfo in fileInfos) { + var dict = fileInfo.MatchClassInfo(fileInfo.Code, classInfos); + foreach (var classInfo in fileInfo.ClassInfos) { + fileInfo.Code = classInfo.ReplaceMethodParamenterType(fileInfo.Code, dict); + } + if (fileInfo.HasChange) { + File.WriteAllText(fileInfo.FileName, fileInfo.Code); + IsChange = true; + } + } + } while (IsChange); } /// /// Convert file, Replace grammar rules @@ -85,6 +107,7 @@ public static string ReplaceCodes(string text, HashSet classNames = null text = classInfo.AddNewField(text); // Add missing fields text = classInfo.ReplaceCodes(text); } + // One file is a static class. There are only static methods in the static class, so I will deal with the static methods in the file. var sss = ClassMethod.AnalysisCodeForStaticMethod(text); foreach (var item in sss) { text = item.ReplaceCodes(text); @@ -628,7 +651,7 @@ private static void getClassName(string text, HashSet classNames) internal static List splitParamenters(string paramenters) { bool inText = false; - int bracketLayer = 0; + int bracketLayer = 0; // List result = new List(); var index = 0; @@ -650,14 +673,14 @@ internal static List splitParamenters(string paramenters) bracketLayer--; temp += c; } else if (c == ',' && bracketLayer == 0) { - result.Add(temp); + result.Add(temp.Trim()); temp = ""; } else { temp += c; } index++; } - result.Add(temp); + result.Add(temp.Trim()); return result; } From ed3c5ba22f9d2853b4a9edfc7a3176a5c7de4b8d Mon Sep 17 00:00:00 2001 From: linzhijun Date: Fri, 3 Mar 2023 13:27:44 +0800 Subject: [PATCH 11/27] Add Run Code --- src/Extensions/TorchCs/ClassInfo.cs | 4 +- src/Extensions/TorchCs/Program.cs | 68 +++++++++++++++++++++++++++ src/Extensions/TorchCs/TorchCs.csproj | 1 + src/Extensions/TorchCs/TorchUtil.cs | 16 ++++--- 4 files changed, 80 insertions(+), 9 deletions(-) create mode 100644 src/Extensions/TorchCs/Program.cs diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index 646dea8..ebd2438 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -330,7 +330,7 @@ public static List AnalysisCode(string code) field1.NewType = "bool"; } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(code, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { field1.NewType = "string"; - } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) (\d+\.\d+|\d+(\.\d+)?[Ee])")) { field1.NewType = "double"; } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { field1.NewType = "int"; @@ -578,7 +578,7 @@ public static List AnalysisCode(string code, string text) classMethodParamenter.NewType = "bool"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.(split|startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { classMethodParamenter.NewType = "string"; - } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) (\d+\.\d+|\d+(\.\d+)?[Ee])")) { classMethodParamenter.NewType = "doulbe"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { classMethodParamenter.NewType = "int"; diff --git a/src/Extensions/TorchCs/Program.cs b/src/Extensions/TorchCs/Program.cs new file mode 100644 index 0000000..b63ce20 --- /dev/null +++ b/src/Extensions/TorchCs/Program.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp.Data; + +namespace TorchCs +{ + public class Program + { + private const string usage = +@"Usage: + TorchCs [options] + +Options: + -d, --dir Convert all files in the directory + -n, --netstandard Generate netstandard.cs file, The parameter has -d or -dir is valid. +"; + static void Main(string[] args) + { + var options = ParseOptions(args); + if (options.Count == 0) { + Console.WriteLine(usage); + return; + } + if (options.ContainsKey("--dir")) { + Console.WriteLine("Conversion directory:" + options["--dir"].ToString()); + TorchUtil.ReplaceFolder(options["--dir"].ToString(), options.ContainsKey("--netstandard")); + } else { + foreach (var item in (List)options[""]) { + Console.WriteLine("Conversion file:" + item); + TorchUtil.ReplaceFile(item); + } + } + if (options.ContainsKey("--netstandard")) { + if (options.ContainsKey("--dir")) { + Console.WriteLine("Generate netstandard.cs file"); + TorchUtil.CreateNetstandardCode(Path.GetDirectoryName(options["--dir"].ToString())); + } + } + Console.WriteLine("Conversion completed!"); + } + + private static IDictionary ParseOptions(string[] args) + { + var result = new Dictionary(); + var files = new List(); + + int index = 0; + while (index < args.Length) { + var arg = args[index++]; + if (!arg.StartsWith('-')) { + files.Add(arg); + + } else if (arg == "-d" || arg == "--dir") { + result["--dir"] = args[index++]; + } else if (arg == "-n" || arg == "--netstandard") { + result["--netstandard"] = true; + } + } + result[""] = files; + return result; + } + + + } +} diff --git a/src/Extensions/TorchCs/TorchCs.csproj b/src/Extensions/TorchCs/TorchCs.csproj index 8a27351..c9d78e1 100644 --- a/src/Extensions/TorchCs/TorchCs.csproj +++ b/src/Extensions/TorchCs/TorchCs.csproj @@ -1,6 +1,7 @@  + Exe net6.0 enable disable diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index f4fb110..467b14c 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -31,7 +31,7 @@ public class TorchUtil /// Convert all *.py.cs files in the folder ,Replace grammar rules /// /// - public static void ReplaceFolder(string folder) + public static void ReplaceFolder(string folder, bool replaceStringToNetstandard = true) { var files = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories); HashSet classNames = new HashSet(); @@ -44,7 +44,7 @@ public static void ReplaceFolder(string folder) classNames.Remove("F"); foreach (var file in files) { var text = File.ReadAllText(file); - File.WriteAllText(file, ReplaceCodes(text, classNames)); + File.WriteAllText(file, ReplaceCodes(text, classNames, replaceStringToNetstandard)); } var fileInfos = ClassFile.LoadFiles(folder); @@ -73,24 +73,24 @@ public static void ReplaceFolder(string folder) /// Convert file, Replace grammar rules ///
/// - public static void ReplaceFile(string file) + public static void ReplaceFile(string file, bool replaceStringToNetstandard = false) { var text = File.ReadAllText(file); - File.WriteAllText(file, ReplaceCodes(text)); + File.WriteAllText(file, ReplaceCodes(text, null, replaceStringToNetstandard)); } /// /// Convert code, Replace grammar rules /// /// /// - public static string ReplaceCodes(string text, HashSet classNames = null) + public static string ReplaceCodes(string text, HashSet classNames = null, bool replaceToNetstandard = true) { // replace 'self' to 'this' text = Regex.Replace(text, @"\bself\.", "this."); // replace field type text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = ""\S+?""[,;)])", "string $2"); text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+[,;)])", "int $2"); - text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+\.\d+[,;)])", "double $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = (\d+\.\d+|\d+(\.\d+)?[Ee]-?\d+)[,;)])", "double $2"); text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = (true|false)[,;)])", "bool $2"); text = Regex.Replace(text, @"\bvoid ([a-zA-Z_][a-zA-Z0-9_]*[ ,);])", "object $1"); // replace 'd_keys = d_keys or (d_model//n_heads)' to 'd_keys = d_keys ?? d_model / n_heads;' @@ -122,7 +122,9 @@ public static string ReplaceCodes(string text, HashSet classNames = null text = replaceTensorList(text); text = replaceIsType(text); - text = replaceStringToNetstandard(text); + if (replaceToNetstandard) { + text = replaceStringToNetstandard(text); + } text = text.Replace("using (var torch.no_grad())", "using (var _no_grad= torch.no_grad())"); text = text.Replace("using (var torch.cuda.amp.autocast())", "using (var _autocast= torch.cuda.amp.autocast())"); From 08fe4163d95bda16829ec4ebd3b453c0f1ad9040 Mon Sep 17 00:00:00 2001 From: toolgood Date: Fri, 3 Mar 2023 21:10:16 +0800 Subject: [PATCH 12/27] add 'remove' method --- src/Extensions/TorchCs/Resources/netstandard.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index 51901fb..c5412c6 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -113,7 +113,10 @@ public static void append(this ICollection list, T obj) { list.Add(obj); } - + public static void remove(this ICollection list, T obj) + { + list.Remove(obj); + } public static ICollection keys(this IDictionary dict) { return dict.Keys; From 40f91fcbc62683563f76e33fb022f3109f8051d1 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Sat, 4 Mar 2023 09:27:40 +0800 Subject: [PATCH 13/27] add python list methods --- .../TorchCs/Resources/netstandard.cs | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index c5412c6..464d54b 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -16,6 +16,7 @@ using System.Runtime.CompilerServices; using System.Security.Cryptography; using static TorchSharp.torch; +using System.Linq; namespace System { @@ -117,6 +118,49 @@ public static void remove(this ICollection list, T obj) { list.Remove(obj); } + public static void extend(this ICollection list, params T[] objs) + { + foreach (var obj in objs) { + list.Add(obj); + } + } + public static int count(this ICollection list, T obj) + { + return list.Where(q => q.Equals(obj)).Count(); + } + public static int index(this ICollection list, T obj) + { + var index = -1; + foreach (var item in list) { + index++; + if (item.Equals(obj)) { + return index; + } + } + return -1; + } + public static void reverse(this ICollection list) + { + list = list.Reverse().ToList(); + } + public static void insert(this IList list, int index, T obj) + { + list.Insert(index, obj); + } + public static T pop(this IList list) + { + var last = list[list.Count - 1]; + list.RemoveAt(list.Count - 1); + return last; + } + public static T pop(this IList list) + { + var last = list[list.Count - 1]; + list.RemoveAt(list.Count - 1); + return last; + } + + public static ICollection keys(this IDictionary dict) { return dict.Keys; From f720d31f897b8c0e77be7172a8712e8f8c9b4024 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Sat, 4 Mar 2023 09:44:30 +0800 Subject: [PATCH 14/27] add python Dictionary method. --- .../TorchCs/Resources/netstandard.cs | 65 +++++++++++++++++-- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index 464d54b..9bf5146 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -153,18 +153,71 @@ public static T pop(this IList list) list.RemoveAt(list.Count - 1); return last; } - public static T pop(this IList list) - { - var last = list[list.Count - 1]; - list.RemoveAt(list.Count - 1); - return last; - } public static ICollection keys(this IDictionary dict) { return dict.Keys; } + public static ICollection values(this IDictionary dict) + { + return dict.Values; + } + public static void clear(this IDictionary dict) + { + dict.Clear(); + } + public static T2 get(this IDictionary dict, T1 key) + { + if (dict.TryGetValue(key, out T2 result)) { + return result; + } + return default(T2); + } + public static T2 get(this IDictionary dict, T1 key, T2 def) + { + if (dict.TryGetValue(key, out T2 result)) { + return result; + } + return def; + } + public static bool has_key(this IDictionary dict, T1 key) + { + return (dict.ContainsKey(key)); + } + public static T2 pop(this IDictionary dict, T1 key) + { + if (dict.TryGetValue(key, out T2 result)) { + dict.Remove(key); + return result; + } + return default(T2); + } + public static T2 pop(this IDictionary dict, T1 key, T2 def) + { + if (dict.TryGetValue(key, out T2 result)) { + dict.Remove(key); + return result; + } + return def; + } + public static (T1, T2) popitem(this IDictionary dict) + { + T1 key = default(T1); + T2 val = default(T2); + foreach (var item in dict) { + key = item.Key; + val = item.Value; + } + if (dict.ContainsKey(key)) { + dict.Remove(key); + } + return (key, val); + } + + + + /// /// Simplify code, similar to python syntax /// python code : B, L = queries.shape From c7954286591e8ab50962e93c2a0e6afc9a1d366a Mon Sep 17 00:00:00 2001 From: linzhijun Date: Sat, 4 Mar 2023 16:23:32 +0800 Subject: [PATCH 15/27] add python copy method --- .../TorchCs/Resources/netstandard.cs | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index 9bf5146..d43e93c 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -109,6 +109,10 @@ public static string rstrip(this string str) return str.TrimEnd(); } + public static T copy(this T obj) where T : ICloneable + { + return (T)obj.Clone(); + } public static void append(this ICollection list, T obj) { @@ -153,6 +157,18 @@ public static T pop(this IList list) list.RemoveAt(list.Count - 1); return last; } + public static ICollection copy(this ICollection list) + { + var newObj = new List(); + newObj.AddRange(list); + return newObj; + } + public static List copy(this List list) + { + var newObj = new List(); + newObj.AddRange(list); + return newObj; + } public static ICollection keys(this IDictionary dict) @@ -215,6 +231,22 @@ public static (T1, T2) popitem(this IDictionary dict) return (key, val); } + public static IDictionary copy(this IDictionary dict) + { + Dictionary copy = new Dictionary(); + foreach (var item in dict) { + copy[item.Key] = item.Value; + } + return copy; + } + public static Dictionary copy(this Dictionary dict) + { + Dictionary copy = new Dictionary(); + foreach (var item in dict) { + copy[item.Key] = item.Value; + } + return copy; + } From c24954c02ae8bbfedd56715d45f26a7019354b04 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Sat, 4 Mar 2023 16:27:16 +0800 Subject: [PATCH 16/27] add license info --- src/Extensions/TorchCs/ClassInfo.cs | 15 +++++++++++++++ src/Extensions/TorchCs/Program.cs | 17 ++++++++++++++++- src/Extensions/TorchCs/TorchSharpInfo.cs | 17 ++++++++++++++++- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index ebd2438..e362533 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -1,3 +1,18 @@ +#region License +// Copyright 2023 ToolGood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion using Pytocs.Core.Types; using System.Text.RegularExpressions; using static System.Net.Mime.MediaTypeNames; diff --git a/src/Extensions/TorchCs/Program.cs b/src/Extensions/TorchCs/Program.cs index b63ce20..c3cfb2d 100644 --- a/src/Extensions/TorchCs/Program.cs +++ b/src/Extensions/TorchCs/Program.cs @@ -1,4 +1,19 @@ -using System; +#region License +// Copyright 2023 ToolGood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/Extensions/TorchCs/TorchSharpInfo.cs b/src/Extensions/TorchCs/TorchSharpInfo.cs index 6a50b9c..eed8e06 100644 --- a/src/Extensions/TorchCs/TorchSharpInfo.cs +++ b/src/Extensions/TorchCs/TorchSharpInfo.cs @@ -1,4 +1,19 @@ -using System.Reflection; +#region License +// Copyright 2023 ToolGood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion +using System.Reflection; using System.Text.RegularExpressions; using TorchSharp; From a1fbb96b186f0e4fd1d2080874febe814ec4f830 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Mon, 6 Mar 2023 10:05:01 +0800 Subject: [PATCH 17/27] Add more default type judgments --- src/Extensions/TorchCs/ClassInfo.cs | 73 ++++++++++++++++++++++------- src/Extensions/TorchCs/TorchUtil.cs | 54 ++++++++++++++++++++- 2 files changed, 110 insertions(+), 17 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index e362533..2379ef5 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -13,9 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #endregion -using Pytocs.Core.Types; using System.Text.RegularExpressions; -using static System.Net.Mime.MediaTypeNames; namespace TorchCs { @@ -263,8 +261,6 @@ public static ClassConstructor AnalysisCode(string code, string className) return classConstructor; } - - public string ReplaceCodes(string code) { code = Regex.Replace(code, constructorRegex.Replace("{name}", ClassName), new MatchEvaluator(m => { @@ -294,7 +290,15 @@ public static List AnalysisCode(string code) { List classFields = new List(); HashSet fields = new HashSet(); - var ms = Regex.Matches(code, "public ([a-zA-Z_][a-zA-Z0-9_<>]*) ([a-zA-Z_][a-zA-Z0-9_]*);"); + var ms = Regex.Matches(code, "public ([a-zA-Z_][a-zA-Z0-9_<>]*) ([a-zA-Z_@][a-zA-Z0-9_]*);"); + foreach (Match match in ms) { + ClassField field = new ClassField(); + field.Type = match.Groups[1].Value; + field.FieldName = match.Groups[2].Value; + fields.Add(field.FieldName); + classFields.Add(field); + } + ms = Regex.Matches(code, "public ([a-zA-Z_][a-zA-Z0-9_<>]*) ([a-zA-Z_@][a-zA-Z0-9_]*) ="); foreach (Match match in ms) { ClassField field = new ClassField(); field.Type = match.Groups[1].Value; @@ -302,7 +306,7 @@ public static List AnalysisCode(string code) fields.Add(field.FieldName); classFields.Add(field); } - ms = Regex.Matches(code, @"\bthis\.([a-zA-Z_][a-zA-Z0-9_]*)[ \t\r\n,;)\[]"); + ms = Regex.Matches(code, @"\bthis\.([a-zA-Z_@][a-zA-Z0-9_]*)[ \t\r\n,;)\[]"); foreach (Match m in ms) { if (fields.Add(m.Groups[1].Value)) { ClassField field = new ClassField(); @@ -351,11 +355,19 @@ public static List AnalysisCode(string code) field1.NewType = "int"; } else if (Regex.IsMatch(code, $@"this\.{name}\[[^\]]*?TensorIndex\.")) { field1.NewType = "Tensor"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(dropout|.*_dropout)$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$")) { + field1.NewType = "OptimizerHelper"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$")) { + field1.NewType = "LRScheduler"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$")) { + field1.NewType = "Dataset"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataloader|loader|.*_loader)$")) { + field1.NewType = "DataLoader"; + } else if (field1.Type == "object" && TorchUtil.isDoubleTypeByName(name)) { field1.NewType = "double"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|epoch|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { + } else if (field1.Type == "object" && TorchUtil.isIntTypeByName(name)) { field1.NewType = "int"; - } else if (field1.Type == "object" && Regex.IsMatch(name, "^(name|path|dir|.*(_path|_name|_dir))$")) { + } else if (field1.Type == "object" && TorchUtil.isStringTypeByName(name)) { field1.NewType = "string"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { // classMethodParamenter.NewType = "double"; @@ -481,6 +493,21 @@ public string ReplaceCodes(string code, List fields = null) } if (max == 1) { NewReturnType = "object"; + var f = ms[0].Value; + if (f.StartsWith("this.")) { + if (fields != null) { + f = f.Substring(5); + var p = fields.FirstOrDefault(q => q.FieldName == f); + if (p != null) { + NewReturnType = p.NewType ?? p.Type; + } + } + } else { + var p = Paramenters.FirstOrDefault(q => q.ParamenterName == f); + if (p != null) { + NewReturnType = p.NewType ?? p.Type; + } + } } else if (max > 1) { NewReturnType = "("; for (int i = 0; i < max; i++) { @@ -558,7 +585,7 @@ public static List AnalysisCode(string code, string text) List classMethodParamenters = new List(); if (string.IsNullOrEmpty(code)) { return classMethodParamenters; } - var strs = Regex.Matches(code, "(.*?) ([a-zA-Z_][a-zA-Z_0-9]*)( = ([^,]+))?(,|$)"); + var strs = Regex.Matches(code, "(.*?) ([a-zA-Z_@][a-zA-Z_0-9]*)( = ([^,]+))?(,|$)"); foreach (Match str in strs) { ClassMethodParamenter classMethodParamenter = new ClassMethodParamenter(); @@ -599,15 +626,29 @@ public static List AnalysisCode(string code, string text) classMethodParamenter.NewType = "int"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\[[^\]]*?TensorIndex\.")) { classMethodParamenter.NewType = "Tensor"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{methodRegex}\(")) { + classMethodParamenter.NewType = "Tensor"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{fieldsRegex}[ ,;)\[]")) { classMethodParamenter.NewType = "Tensor"; - } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{methodRegex}\(")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(label|pred|preds|target|x_enc|x_mark_enc|x_dec|x_mark_dec)$")) { classMethodParamenter.NewType = "Tensor"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(dropout|.*_dropout)$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$")) { + classMethodParamenter.NewType = "Dataset"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(loader|.*_loader)$")) { + classMethodParamenter.NewType = "DataLoader"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$")) { + classMethodParamenter.NewType = "OptimizerHelper"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$")) { + classMethodParamenter.NewType = "LRScheduler"; + } else if (classMethodParamenter.Type == "object" && TorchUtil.isDoubleTypeByName(name)) { classMethodParamenter.NewType = "double"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(channels|index|length|step|epoch|(num_|n_).*|.*(_len|_in|_model|_out|_channels|_size|_dims|_count|_index))$")) { - classMethodParamenter.NewType = "int"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, "^(name|path|dir|.*(_path|_name|_dir))$")) { + } else if (classMethodParamenter.Type == "object" && TorchUtil.isIntTypeByName(name)) { + if (classMethodParamenter.DefaultValue == "null") { + classMethodParamenter.NewType = "int?"; + } else { + classMethodParamenter.NewType = "int"; + } + } else if (classMethodParamenter.Type == "object" && TorchUtil.isStringTypeByName(name)) { classMethodParamenter.NewType = "string"; //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { // classMethodParamenter.NewType = "double"; @@ -707,7 +748,7 @@ public string ReplaceCodes(string code) public override string ToString() { - return $"variable: {NewType??Type} {VariableName}"; + return $"variable: {NewType ?? Type} {VariableName}"; } } diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 467b14c..ba595a2 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -649,7 +649,6 @@ private static void getClassName(string text, HashSet classNames) } } - internal static List splitParamenters(string paramenters) { bool inText = false; @@ -686,5 +685,58 @@ internal static List splitParamenters(string paramenters) return result; } + /// + /// Judge whether it is a Double type according to the parameter name + /// + /// + /// + internal static bool isDoubleTypeByName(string name) + { + if (Regex.IsMatch(name, "^(dropout|lr|lr_step|factor|lr_max|num)$")) { + return true; + } + if (Regex.IsMatch(name, "^.*(_dropout|_factor|_momentum|_lr|_min|_max)$")) { + return true; + } + return false; + } + /// + /// Judge whether it is a Int type according to the parameter name + /// + /// + /// + internal static bool isIntTypeByName(string name) + { + if (Regex.IsMatch(name, "^(channels|index|length|step|epoch|stride|total_steps|d_k|d_v|d_q)$")) { + return true; + } + if (Regex.IsMatch(name, "^.*(_len|_length|_in|_model|_out|_channels|_size|_dims|_count|_index|_epoch|_num|_side)$")) { + return true; + } + if (Regex.IsMatch(name, "^(num_|n_).*$")) { + return true; + } + if (Regex.IsMatch(name, "^.*(_num_|_len_).*$")) { + return true; + } + return false; + } + /// + /// Judge whether it is a String type according to the parameter name + /// + /// + /// + internal static bool isStringTypeByName(string name) + { + if (Regex.IsMatch(name, "^(name|path|dir|file|device)$")) { + return true; + } + if (Regex.IsMatch(name, "^.*(_path|_name|_dir|file|_str|_txt)$")) { + return true; + } + return false; + } + + } } From bddecfd8c90d5aa48cbf6cf026887f32592ec3a2 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Mon, 6 Mar 2023 10:18:21 +0800 Subject: [PATCH 18/27] add 'RegexOptions.IgnoreCase' and remark --- src/Extensions/TorchCs/ClassInfo.cs | 18 +++++++------- src/Extensions/TorchCs/TorchUtil.cs | 37 +++++++++++++++++++---------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index 2379ef5..e7b1fdb 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -355,13 +355,13 @@ public static List AnalysisCode(string code) field1.NewType = "int"; } else if (Regex.IsMatch(code, $@"this\.{name}\[[^\]]*?TensorIndex\.")) { field1.NewType = "Tensor"; - } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$", RegexOptions.IgnoreCase)) { field1.NewType = "OptimizerHelper"; - } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$", RegexOptions.IgnoreCase)) { field1.NewType = "LRScheduler"; - } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$", RegexOptions.IgnoreCase)) { field1.NewType = "Dataset"; - } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataloader|loader|.*_loader)$")) { + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataloader|loader|.*_loader)$", RegexOptions.IgnoreCase)) { field1.NewType = "DataLoader"; } else if (field1.Type == "object" && TorchUtil.isDoubleTypeByName(name)) { field1.NewType = "double"; @@ -630,15 +630,15 @@ public static List AnalysisCode(string code, string text) classMethodParamenter.NewType = "Tensor"; } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{fieldsRegex}[ ,;)\[]")) { classMethodParamenter.NewType = "Tensor"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(label|pred|preds|target|x_enc|x_mark_enc|x_dec|x_mark_dec)$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(label|pred|preds|target|targets|x_enc|x_mark_enc|x_dec|x_mark_dec)$", RegexOptions.IgnoreCase)) { classMethodParamenter.NewType = "Tensor"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$", RegexOptions.IgnoreCase)) { classMethodParamenter.NewType = "Dataset"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(loader|.*_loader)$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(loader|.*_loader)$", RegexOptions.IgnoreCase)) { classMethodParamenter.NewType = "DataLoader"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$", RegexOptions.IgnoreCase)) { classMethodParamenter.NewType = "OptimizerHelper"; - } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$")) { + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$", RegexOptions.IgnoreCase)) { classMethodParamenter.NewType = "LRScheduler"; } else if (classMethodParamenter.Type == "object" && TorchUtil.isDoubleTypeByName(name)) { classMethodParamenter.NewType = "double"; diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index ba595a2..482dcd7 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -619,11 +619,16 @@ private static string replaceStringToNetstandard(string text) return text; } - + /// + /// Add 'new' word to class initialization + /// + /// + /// + /// private static string replaceNewClass(string text, HashSet classNames) { if (classNames == null) { return text; } - const string classRegex = @"using ([a-zA-Z_][a-zA-Z0-9_]*) = ([a-zA-Z_][a-zA-Z0-9_.]*);"; + const string classRegex = @"using ([a-zA-Z_@][a-zA-Z0-9_]*) = ([a-zA-Z_@][a-zA-Z0-9_.@]*);"; List names = new List(); var ms = Regex.Matches(text, classRegex); @@ -639,7 +644,11 @@ private static string replaceNewClass(string text, HashSet classNames) text = Regex.Replace(text, @"\bnew new ", "new "); return text; } - + /// + /// Get all type names, excluding static classes + /// + /// + /// private static void getClassName(string text, HashSet classNames) { const string classRegex = @"public class ([a-zA-Z_][a-zA-Z0-9_]*)"; @@ -648,7 +657,11 @@ private static void getClassName(string text, HashSet classNames) classNames.Add(m.Groups[1].Value); } } - + /// + /// Split parameter, applicable to method definition and method call + /// + /// + /// internal static List splitParamenters(string paramenters) { bool inText = false; @@ -692,10 +705,10 @@ internal static List splitParamenters(string paramenters) /// internal static bool isDoubleTypeByName(string name) { - if (Regex.IsMatch(name, "^(dropout|lr|lr_step|factor|lr_max|num)$")) { + if (Regex.IsMatch(name, "^(dropout|lr|lr_step|factor|lr_max|num)$", RegexOptions.IgnoreCase)) { return true; } - if (Regex.IsMatch(name, "^.*(_dropout|_factor|_momentum|_lr|_min|_max)$")) { + if (Regex.IsMatch(name, "^.*(_dropout|_factor|_momentum|_lr|_min|_max)$", RegexOptions.IgnoreCase)) { return true; } return false; @@ -707,16 +720,16 @@ internal static bool isDoubleTypeByName(string name) /// internal static bool isIntTypeByName(string name) { - if (Regex.IsMatch(name, "^(channels|index|length|step|epoch|stride|total_steps|d_k|d_v|d_q)$")) { + if (Regex.IsMatch(name, "^(channels|index|length|step|epoch|stride|total_steps|d_k|d_v|d_q)$", RegexOptions.IgnoreCase)) { return true; } - if (Regex.IsMatch(name, "^.*(_len|_length|_in|_model|_out|_channels|_size|_dims|_count|_index|_epoch|_num|_side)$")) { + if (Regex.IsMatch(name, "^.*(_len|_length|_in|_model|_out|_channels|_size|_dims|_count|_index|_epoch|_num|_side)$", RegexOptions.IgnoreCase)) { return true; } - if (Regex.IsMatch(name, "^(num_|n_).*$")) { + if (Regex.IsMatch(name, "^(num_|n_).*$", RegexOptions.IgnoreCase)) { return true; } - if (Regex.IsMatch(name, "^.*(_num_|_len_).*$")) { + if (Regex.IsMatch(name, "^.*(_num_|_len_).*$", RegexOptions.IgnoreCase)) { return true; } return false; @@ -728,10 +741,10 @@ internal static bool isIntTypeByName(string name) /// internal static bool isStringTypeByName(string name) { - if (Regex.IsMatch(name, "^(name|path|dir|file|device)$")) { + if (Regex.IsMatch(name, "^(name|path|dir|file|device)$", RegexOptions.IgnoreCase)) { return true; } - if (Regex.IsMatch(name, "^.*(_path|_name|_dir|file|_str|_txt)$")) { + if (Regex.IsMatch(name, "^.*(_path|_name|_dir|file|_str|_txt)$", RegexOptions.IgnoreCase)) { return true; } return false; From d62576913069faaac6d79ce95f88d6728be6a1c2 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Mon, 6 Mar 2023 15:19:58 +0800 Subject: [PATCH 19/27] replace throw new ValueError --- src/Extensions/TorchCs/TorchUtil.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 482dcd7..1c544b9 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -95,6 +95,8 @@ public static string ReplaceCodes(string text, HashSet classNames = null text = Regex.Replace(text, @"\bvoid ([a-zA-Z_][a-zA-Z0-9_]*[ ,);])", "object $1"); // replace 'd_keys = d_keys or (d_model//n_heads)' to 'd_keys = d_keys ?? d_model / n_heads;' text = Regex.Replace(text, @"([a-zA-Z_0-9]+) = (\1 \|\| (.*?;))", "$1 = $1 ?? $3 //$2"); + // replace throw new ValueError + text = text.Replace("throw new ValueError(", "throw new ArgumentException("); text = replaceNamespace(text); text = replaceConstructor(text); From a517174523eb8ba0adfef56ac70709e94fe3978f Mon Sep 17 00:00:00 2001 From: linzhijun Date: Tue, 7 Mar 2023 08:46:52 +0800 Subject: [PATCH 20/27] Add method regular matching with ReturnType 'object []' --- src/Extensions/TorchCs/ClassInfo.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index e7b1fdb..907fbc4 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -408,9 +408,9 @@ public override string ToString() public class ClassMethod { - private const string methodRegex = @"public (virtual) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; - private const string methodRegex2 = @"public (virtual|static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) {name}\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; - private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex = @"public (virtual) ([a-zA-Z_][a-zA-Z0-9_\[\]]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex2 = @"public (virtual|static) ([a-zA-Z_][a-zA-Z0-9_\[\]]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) {name}\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_\[\]]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; // public static int get_q_k(int input_size, int window_size, object stride, object device) internal ClassInfo ClassInfo { get; set; } public string MethodName { get; set; } @@ -479,7 +479,7 @@ public string ReplaceCodes(string code, List fields = null) NewReturnType = NewReturnType.Replace("object", "Tensor"); NewReturnType = NewReturnType.Replace("void", "Tensor"); } - } else if (ReturnType == "void" || ReturnType == "object") { + } else if (ReturnType == "void" || ReturnType == "object" || ReturnType == "object[]") { var ms = Regex.Matches(bodyCode, "return ([^;]*);"); var max = 0; foreach (Match item in ms) { From fef4a43ac9c761fe71ac746d0ed97b8c085d75b6 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Tue, 7 Mar 2023 09:07:11 +0800 Subject: [PATCH 21/27] Add variable name matching in method --- src/Extensions/TorchCs/ClassInfo.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs index 907fbc4..c76a703 100644 --- a/src/Extensions/TorchCs/ClassInfo.cs +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -290,7 +290,7 @@ public static List AnalysisCode(string code) { List classFields = new List(); HashSet fields = new HashSet(); - var ms = Regex.Matches(code, "public ([a-zA-Z_][a-zA-Z0-9_<>]*) ([a-zA-Z_@][a-zA-Z0-9_]*);"); + var ms = Regex.Matches(code, @"public ([a-zA-Z_][a-zA-Z0-9_\<\>\[\]\?]*) ([a-zA-Z_@][a-zA-Z0-9_]*);"); foreach (Match match in ms) { ClassField field = new ClassField(); field.Type = match.Groups[1].Value; @@ -298,7 +298,7 @@ public static List AnalysisCode(string code) fields.Add(field.FieldName); classFields.Add(field); } - ms = Regex.Matches(code, "public ([a-zA-Z_][a-zA-Z0-9_<>]*) ([a-zA-Z_@][a-zA-Z0-9_]*) ="); + ms = Regex.Matches(code, @"public ([a-zA-Z_][a-zA-Z0-9_\<\>\[\]\?]*) ([a-zA-Z_@][a-zA-Z0-9_]*) ="); foreach (Match match in ms) { ClassField field = new ClassField(); field.Type = match.Groups[1].Value; @@ -698,7 +698,7 @@ public static List AnalysisCode(string code, List\[\]]*) ([a-zA-Z_][a-zA-Z0-9_]*)(;| = )"); if (m.Success) { if (names.Add(m.Groups[1].Value)) { ClassMethodVariable classMethodVariable = new ClassMethodVariable(); @@ -709,7 +709,7 @@ public static List AnalysisCode(string code, List @@ -91,13 +97,16 @@ public static void Main(string[] argv) logger.Error("Unable to load {0}.", path); continue; } + string outputPath = Path.ChangeExtension(path, ".py.cs").Replace(startDir, outputDir); + Directory.CreateDirectory(Path.GetDirectoryName(outputPath)!); + xlator.TranslateModuleStatements( module.Body.Statements, types, - Path.ChangeExtension(path, ".py.cs")); + outputPath); } }); - } + } else { if (!options.TryGetValue("", out var oFiles) || @@ -134,13 +143,14 @@ public static void Main(string[] argv) } } - private static IDictionary ParseOptions(string[] args) + private static IDictionary ParseOptions(string[] args) { var result = new Dictionary(); var files = new List(); - for (int i = 0; i < args.Length; ++i) + int i = 0; + while (i < args.Length) { - var arg = args[i]; + var arg = args[i++]; if (!arg.StartsWith('-')) { files = args.Skip(i).ToList(); @@ -161,15 +171,25 @@ private static IDictionary ParseOptions(string[] args) var dirname = "."; if (i < args.Length - 1) { - if (!args[i + 1].StartsWith('-')) + if (!args[i].StartsWith('-')) { - ++i; - dirname = args[i]; + dirname = args[i++]; } - break; } result["--recursive"] = dirname; break; + case "-o": + case "--output": + var dirname2 = "."; + if (i < args.Length - 1) + { + if (!args[i].StartsWith('-')) + { + dirname2 = args[i++]; + } + } + result["--output"] = dirname2; + break; } } result[""] = files; From c6cef913fe256391682193c87b8bac654e884fc6 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Tue, 7 Mar 2023 10:30:37 +0800 Subject: [PATCH 23/27] fix cli bug --- src/Cli/Program.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Cli/Program.cs b/src/Cli/Program.cs index 4eadb87..1b5eb41 100644 --- a/src/Cli/Program.cs +++ b/src/Cli/Program.cs @@ -65,7 +65,7 @@ public static void Main(string[] argv) var startDir = (string) oStartDir; if (startDir == "." || startDir == "./" || startDir == ".\\") startDir = Directory.GetCurrentDirectory(); - startDir=Path.GetFullPath(startDir); + startDir = Path.GetFullPath(startDir); typeAnalysis.Analyze(startDir); typeAnalysis.Finish(); var types = new TypeReferenceTranslator(typeAnalysis.BuildTypeDictionary()); @@ -74,7 +74,7 @@ public static void Main(string[] argv) //{ // Console.WriteLine("{0}: {1} {2}", de.Key, de.Key.Start, de.Value); //} - var outputDir = options.ContainsKey("--output") ? (string)options["--output"] : startDir; + var outputDir = options.ContainsKey("--output") ? (string) options["--output"] : startDir; if (outputDir == "." || outputDir == "./" || outputDir == ".\\") outputDir = Directory.GetCurrentDirectory(); outputDir = Path.GetFullPath(outputDir); @@ -169,7 +169,7 @@ private static IDictionary ParseOptions(string[] args) case "-r": case "--recursive": var dirname = "."; - if (i < args.Length - 1) + if (i < args.Length) { if (!args[i].StartsWith('-')) { @@ -181,7 +181,7 @@ private static IDictionary ParseOptions(string[] args) case "-o": case "--output": var dirname2 = "."; - if (i < args.Length - 1) + if (i < args.Length) { if (!args[i].StartsWith('-')) { From 21b779bc380c18b7944b7a9abc4dfeac3333cd89 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 9 Mar 2023 09:31:43 +0800 Subject: [PATCH 24/27] fix Convert python's [:,:,:] syntax: ':' => TensorIndex.Colon --- src/Extensions/TorchCs/TorchUtil.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 1c544b9..92f3eab 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -522,7 +522,7 @@ private static string replaceListSlice(string text) List list = new List(); foreach (var str in strs) { if (str.Trim() == "\":\"") { - list.Add("TensorIndex.Ellipsis"); + list.Add("TensorIndex.Colon"); } else if (str.Trim() == "") { list.Add("TensorIndex.Null"); } else if (str.Contains(":")) { From e15553e8b788077be1b1886ff8314c31146a945e Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 9 Mar 2023 10:14:23 +0800 Subject: [PATCH 25/27] Adapt python code torch.arange(B)[:, None, ] == torch.arange(B)[:, None] --- src/Extensions/TorchCs/TorchUtil.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 92f3eab..b74683b 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -523,7 +523,9 @@ private static string replaceListSlice(string text) foreach (var str in strs) { if (str.Trim() == "\":\"") { list.Add("TensorIndex.Colon"); - } else if (str.Trim() == "") { + } else if (str.Trim() == "") { // python code: torch.arange(B)[:, None, ] == torch.arange(B)[:, None] + //list.Add(""); + } else if ( str.Trim() == "null") { list.Add("TensorIndex.Null"); } else if (str.Contains(":")) { var ss = str.Trim().Split(':'); From a625ba941b2cf3772e6f39e0f0f04648f6910ab7 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 9 Mar 2023 13:19:22 +0800 Subject: [PATCH 26/27] Turn off netstandard.cs code style prompt --- src/Extensions/TorchCs/Resources/netstandard.cs | 6 +++++- src/Extensions/TorchCs/TorchCs.csproj | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index d43e93c..892a3ce 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -341,6 +341,7 @@ public static partial class TorchEnumerable } +#pragma warning disable IDE1006 // 命名样式 public static class os { public static void makedirs(string path) @@ -432,10 +433,13 @@ public static string abspath(string path) return Path.GetFullPath(path); } +#pragma warning disable CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 public static string? dirname(string path) { return Path.GetDirectoryName(path); } +#pragma warning restore CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 + public static long getsize(string path) { return new FileInfo(path).Length; @@ -465,7 +469,7 @@ public static void sleep(int s) Thread.Sleep(s * 1000); } } - +#pragma warning restore IDE1006 // 命名样式 } diff --git a/src/Extensions/TorchCs/TorchCs.csproj b/src/Extensions/TorchCs/TorchCs.csproj index c9d78e1..407feb8 100644 --- a/src/Extensions/TorchCs/TorchCs.csproj +++ b/src/Extensions/TorchCs/TorchCs.csproj @@ -12,7 +12,7 @@ - + From b1ca10e8e57288cdc0d19c2c390d94bb5d0331d7 Mon Sep 17 00:00:00 2001 From: linzhijun Date: Thu, 9 Mar 2023 13:55:49 +0800 Subject: [PATCH 27/27] add PythonFile class --- .../TorchCs/Resources/netstandard.cs | 180 +++++++++++++++++- 1 file changed, 176 insertions(+), 4 deletions(-) diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index 892a3ce..ff0350f 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -17,7 +17,11 @@ using System.Security.Cryptography; using static TorchSharp.torch; using System.Linq; +using System.Text; +#pragma warning disable IDE1006 // 命名样式 +#pragma warning disable CS8981 // 该类型名称仅包含小写 ascii 字符。此类名称可能会成为该语言的保留值。 +#pragma warning disable CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 namespace System { public static partial class TorchExtension @@ -341,7 +345,6 @@ public static partial class TorchEnumerable } -#pragma warning disable IDE1006 // 命名样式 public static class os { public static void makedirs(string path) @@ -433,12 +436,10 @@ public static string abspath(string path) return Path.GetFullPath(path); } -#pragma warning disable CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 public static string? dirname(string path) { return Path.GetDirectoryName(path); } -#pragma warning restore CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 public static long getsize(string path) { @@ -469,7 +470,178 @@ public static void sleep(int s) Thread.Sleep(s * 1000); } } -#pragma warning restore IDE1006 // 命名样式 + public class PythonFile + { + private System.IO.FileStream fileStream; + private bool bin; + + public static PythonFile open(string file, string mode = "+", string encoding = "UTF-8") + { + PythonFile result = new PythonFile(); + + if (mode.Contains("+")) + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.ReadWrite); + if (mode.Contains("a")) + { + result.fileStream.Seek(0, SeekOrigin.End); + } + } else if (mode.Contains("a")) + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.Write); + result.fileStream.Seek(0, SeekOrigin.End); + } else if (mode.Contains("w")) + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.Write); + } else + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.Read); + } + result.bin = mode.Contains("b"); + return result; + } + public string[] readline(int size = 1) + { + var read = new System.IO.StreamReader(fileStream); + string[] result = new string[size]; + for (int i = 0; i < size; i++) + { + result[i] = read.ReadLine(); + } + read.ReadToEnd(); + return result; + } + public string readline() + { + var read = new System.IO.StreamReader(fileStream); + string result = read.ReadLine(); + read.ReadToEnd(); + return result; + } + public string read() + { + var read = new System.IO.StreamReader(fileStream); + var r = read.Read(); + read.ReadToEnd(); + return ((char) r).ToString(); + } + + public string read(int size = 1) + { + if (size <= 0) + { + var read = new System.IO.StreamReader(fileStream); + var r = read.ReadToEnd(); + read.ReadToEnd(); + return r; + } else + { + var read = new System.IO.StreamReader(fileStream); + StringBuilder stringBuilder = new StringBuilder(); + for (int i = 0; i < size; i++) + { + var r = read.Read(); + stringBuilder.Append((char) r); + } + read.ReadToEnd(); + return stringBuilder.ToString(); + } + } + + public void write(string txt) + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(txt); + write.Close(); + } + + public void write(double num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + public void write(float num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + public void write(int num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + public void write(long num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + + public void seek(int offset, int whence = 0) + { + if (whence == 0) + { + fileStream.Seek(offset, SeekOrigin.Begin); + } else if (whence == 1) + { + fileStream.Seek(offset, SeekOrigin.Current); + } else if (whence == 2) + { + fileStream.Seek(offset, SeekOrigin.End); + } else + { + throw new Exception("whence is error."); + } + } + + public long tell() + { + return fileStream.Position; + } + + public void close() + { + fileStream.Close(); + } + } } +#pragma warning restore CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 +#pragma warning restore CS8981 // 该类型名称仅包含小写 ascii 字符。此类名称可能会成为该语言的保留值。 +#pragma warning restore IDE1006 // 命名样式 \ No newline at end of file