diff --git a/FSharpWorkspaceShim/FSharpWorkspaceShim.fsproj b/FSharpWorkspaceShim/FSharpWorkspaceShim.fsproj index 41b0ba9c3..be8b7ff92 100644 --- a/FSharpWorkspaceShim/FSharpWorkspaceShim.fsproj +++ b/FSharpWorkspaceShim/FSharpWorkspaceShim.fsproj @@ -9,9 +9,7 @@ false - - - + diff --git a/MLS.Agent/CommandLine/KernelServerCommand.cs b/MLS.Agent/CommandLine/KernelServerCommand.cs index 9b1b16bb6..36d4c261f 100644 --- a/MLS.Agent/CommandLine/KernelServerCommand.cs +++ b/MLS.Agent/CommandLine/KernelServerCommand.cs @@ -5,7 +5,6 @@ using System.CommandLine; using System.Threading.Tasks; using Microsoft.DotNet.Interactive; -using WorkspaceServer.Kernel; namespace MLS.Agent.CommandLine { diff --git a/Microsoft.DotNet.Interactive.FSharp/FSharpKernel.fs b/Microsoft.DotNet.Interactive.FSharp/FSharpKernel.fs index 429709bcc..9aabe49c4 100644 --- a/Microsoft.DotNet.Interactive.FSharp/FSharpKernel.fs +++ b/Microsoft.DotNet.Interactive.FSharp/FSharpKernel.fs @@ -35,10 +35,19 @@ type FSharpKernel() = context.Publish(CodeSubmissionEvaluated(codeSubmission)) context.Complete() } + + let handleCancelCurrentCommand (cancelCurrentCommand: CancelCurrentCommand) (context: KernelInvocationContext) = + async { + let reply = CurrentCommandCancelled(cancelCurrentCommand) + context.Publish(reply) + context.Complete() + } + override __.HandleAsync(command: IKernelCommand, _context: KernelInvocationContext): Task = async { match command with | :? SubmitCode as submitCode -> submitCode.Handler <- fun invocationContext -> (handleSubmitCode submitCode invocationContext) |> Async.StartAsTask :> Task + | :? CancelCurrentCommand as cancelCurrentCommand -> cancelCurrentCommand.Handler <- fun invocationContext -> (handleCancelCurrentCommand cancelCurrentCommand invocationContext) |> Async.StartAsTask :> Task | _ -> () } |> Async.StartAsTask :> Task diff --git a/Microsoft.DotNet.Interactive.Jupyter.Tests/InterruptRequestHandlerTests.cs b/Microsoft.DotNet.Interactive.Jupyter.Tests/InterruptRequestHandlerTests.cs new file mode 100644 index 000000000..4fe4eb32a --- /dev/null +++ b/Microsoft.DotNet.Interactive.Jupyter.Tests/InterruptRequestHandlerTests.cs @@ -0,0 +1,66 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.DotNet.Interactive.Jupyter.Protocol; +using Recipes; +using WorkspaceServer.Kernel; +using Xunit; + +namespace Microsoft.DotNet.Interactive.Jupyter.Tests +{ + public class InterruptRequestHandlerTests + { + private readonly MessageSender _ioPubChannel; + private readonly MessageSender _serverChannel; + private readonly RecordingSocket _serverRecordingSocket; + private readonly RecordingSocket _ioRecordingSocket; + private readonly KernelStatus _kernelStatus; + + public InterruptRequestHandlerTests() + { + var signatureValidator = new SignatureValidator("key", "HMACSHA256"); + _serverRecordingSocket = new RecordingSocket(); + _serverChannel = new MessageSender(_serverRecordingSocket, signatureValidator); + _ioRecordingSocket = new RecordingSocket(); + _ioPubChannel = new MessageSender(_ioRecordingSocket, signatureValidator); + _kernelStatus = new KernelStatus(); + } + + [Fact] + public void cannot_handle_requests_that_are_not_InterruptRequest() + { + var kernel = new CSharpKernel(); + var handler = new InterruptRequestHandler(kernel); + var request = Message.Create(new DisplayData(), null); + Func messageHandling = () => handler.Handle(new JupyterRequestContext(_serverChannel, _ioPubChannel, request, _kernelStatus)); + messageHandling.Should().ThrowExactly(); + } + + [Fact] + public async Task handles_InterruptRequest() + { + var kernel = new CSharpKernel(); + var handler = new InterruptRequestHandler(kernel); + var request = Message.Create(new InterruptRequest(), null); + await handler.Handle(new JupyterRequestContext(_serverChannel, _ioPubChannel, request, _kernelStatus)); + } + + [Fact] + public async Task sends_InterruptReply() + { + var kernel = new CSharpKernel(); + var handler = new InterruptRequestHandler(kernel); + var request = Message.Create(new InterruptRequest(), null); + await handler.Handle(new JupyterRequestContext(_serverChannel, _ioPubChannel, request, _kernelStatus)); + + _serverRecordingSocket.DecodedMessages.SingleOrDefault(message => + message.Contains(MessageTypeValues.InterruptReply)) + .Should() + .NotBeNullOrWhiteSpace(); + } + } +} \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive.Jupyter/InterruptRequestHandler.cs b/Microsoft.DotNet.Interactive.Jupyter/InterruptRequestHandler.cs new file mode 100644 index 000000000..cfa710fdb --- /dev/null +++ b/Microsoft.DotNet.Interactive.Jupyter/InterruptRequestHandler.cs @@ -0,0 +1,64 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Reactive.Concurrency; +using System.Threading.Tasks; +using Microsoft.DotNet.Interactive.Commands; +using Microsoft.DotNet.Interactive.Events; +using Microsoft.DotNet.Interactive.Jupyter.Protocol; + +namespace Microsoft.DotNet.Interactive.Jupyter +{ + public class InterruptRequestHandler : RequestHandlerBase + { + + public InterruptRequestHandler(IKernel kernel, IScheduler scheduler = null) + : base(kernel, scheduler ?? CurrentThreadScheduler.Instance) + { + + } + + protected override void OnKernelEvent(IKernelEvent @event) + { + switch (@event) + { + case CurrentCommandCancelled kernelInterrupted: + OnExecutionInterrupted(kernelInterrupted); + break; + } + } + + private void OnExecutionInterrupted(CurrentCommandCancelled currentCommandCancelled) + { + if (InFlightRequests.TryRemove(currentCommandCancelled.Command, out var openRequest)) + { + // reply + var interruptReplyPayload = new InterruptReply(); + + // send to server + var interruptReply = Message.CreateResponse( + interruptReplyPayload, + openRequest.Context.Request); + + openRequest.Context.ServerChannel.Send(interruptReply); + openRequest.Context.RequestHandlerStatus.SetAsIdle(); + openRequest.Dispose(); + } + } + + public async Task Handle(JupyterRequestContext context) + { + var interruptRequest = GetJupyterRequest(context); + + context.RequestHandlerStatus.SetAsBusy(); + + var command = new CancelCurrentCommand(); + + var openRequest = new InflightRequest(context, interruptRequest, 0); + + InFlightRequests[command] = openRequest; + + await Kernel.SendAsync(command); + } + } +} \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive.Jupyter/IsCompleteRequestHandler.cs b/Microsoft.DotNet.Interactive.Jupyter/IsCompleteRequestHandler.cs index 054fe360f..c0ddf9e70 100644 --- a/Microsoft.DotNet.Interactive.Jupyter/IsCompleteRequestHandler.cs +++ b/Microsoft.DotNet.Interactive.Jupyter/IsCompleteRequestHandler.cs @@ -63,7 +63,5 @@ private void OnKernelEvent(IKernelEvent @event, bool isComplete) openRequest.Dispose(); } } - - } } \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive.Jupyter/JupyterRequestContextHandler.cs b/Microsoft.DotNet.Interactive.Jupyter/JupyterRequestContextHandler.cs index 07dde3925..cd335863e 100644 --- a/Microsoft.DotNet.Interactive.Jupyter/JupyterRequestContextHandler.cs +++ b/Microsoft.DotNet.Interactive.Jupyter/JupyterRequestContextHandler.cs @@ -15,7 +15,8 @@ public class JupyterRequestContextHandler : ICommandHandler Handle( case MessageTypeValues.CompleteRequest: await _completeHandler.Handle(delivery.Command); break; + case MessageTypeValues.InterruptRequest: + await _interruptHandler.Handle(delivery.Command); + break; } return delivery.Complete(); diff --git a/Microsoft.DotNet.Interactive.Jupyter/Protocol/InterruptReply.cs b/Microsoft.DotNet.Interactive.Jupyter/Protocol/InterruptReply.cs new file mode 100644 index 000000000..530d73223 --- /dev/null +++ b/Microsoft.DotNet.Interactive.Jupyter/Protocol/InterruptReply.cs @@ -0,0 +1,11 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.DotNet.Interactive.Jupyter.Protocol +{ + [JupyterMessageType(MessageTypeValues.InterruptReply)] + public class InterruptReply : JupyterMessageContent + { + + } +} \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive.Jupyter/Protocol/InterruptRequest.cs b/Microsoft.DotNet.Interactive.Jupyter/Protocol/InterruptRequest.cs new file mode 100644 index 000000000..b34872baf --- /dev/null +++ b/Microsoft.DotNet.Interactive.Jupyter/Protocol/InterruptRequest.cs @@ -0,0 +1,11 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.DotNet.Interactive.Jupyter.Protocol +{ + [JupyterMessageType(MessageTypeValues.InterruptRequest)] + public class InterruptRequest : JupyterMessageContent + { + + } +} \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive.Jupyter/Protocol/MessageTypeValues.cs b/Microsoft.DotNet.Interactive.Jupyter/Protocol/MessageTypeValues.cs index 0be82eca5..b7216767f 100644 --- a/Microsoft.DotNet.Interactive.Jupyter/Protocol/MessageTypeValues.cs +++ b/Microsoft.DotNet.Interactive.Jupyter/Protocol/MessageTypeValues.cs @@ -62,5 +62,9 @@ public class MessageTypeValues public const string CommInfoRequest = "comm_info_request"; public const string CommInfoReply = "comm_info_reply"; + + public const string InterruptRequest = "interrupt_request"; + + public const string InterruptReply = "interrupt_reply"; } } diff --git a/Microsoft.DotNet.Interactive/Commands/CancelCurrentCommand.cs b/Microsoft.DotNet.Interactive/Commands/CancelCurrentCommand.cs new file mode 100644 index 000000000..92b17d300 --- /dev/null +++ b/Microsoft.DotNet.Interactive/Commands/CancelCurrentCommand.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.DotNet.Interactive.Commands +{ + public class CancelCurrentCommand : KernelCommandBase + { + + } +} \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive/Commands/RequestCompletion.cs b/Microsoft.DotNet.Interactive/Commands/RequestCompletion.cs index ed27b9607..ce991388b 100644 --- a/Microsoft.DotNet.Interactive/Commands/RequestCompletion.cs +++ b/Microsoft.DotNet.Interactive/Commands/RequestCompletion.cs @@ -1,7 +1,5 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -// Copyright (c) .NET Foundation and contributors. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; diff --git a/Microsoft.DotNet.Interactive/Commands/SubmitCode.cs b/Microsoft.DotNet.Interactive/Commands/SubmitCode.cs index a34a6dcb3..2fb4938d4 100644 --- a/Microsoft.DotNet.Interactive/Commands/SubmitCode.cs +++ b/Microsoft.DotNet.Interactive/Commands/SubmitCode.cs @@ -20,6 +20,7 @@ public SubmitCode( public string Code { get; set; } public string TargetKernelName { get; set; } + public SubmissionType SubmissionType { get; } public override string ToString() => $"{base.ToString()}: {Code.TruncateForDisplay()}"; diff --git a/Microsoft.DotNet.Interactive/Events/CurrentCommandCancelled.cs b/Microsoft.DotNet.Interactive/Events/CurrentCommandCancelled.cs new file mode 100644 index 000000000..8dad375e2 --- /dev/null +++ b/Microsoft.DotNet.Interactive/Events/CurrentCommandCancelled.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.DotNet.Interactive.Commands; + +namespace Microsoft.DotNet.Interactive.Events +{ + public class CurrentCommandCancelled:KernelEventBase + { + public CurrentCommandCancelled(IKernelCommand command) : base(command) + { + + } + } +} \ No newline at end of file diff --git a/Microsoft.DotNet.Interactive/KernelStreamClient.cs b/Microsoft.DotNet.Interactive/KernelStreamClient.cs index a0d8affb3..3c1c80298 100644 --- a/Microsoft.DotNet.Interactive/KernelStreamClient.cs +++ b/Microsoft.DotNet.Interactive/KernelStreamClient.cs @@ -18,6 +18,7 @@ public class KernelStreamClient private readonly TextReader _input; private readonly TextWriter _output; private readonly CommandDeserializer _deserializer = new CommandDeserializer(); + private readonly JsonSerializerSettings _jsonSerializerSettings = new JsonSerializerSettings { ContractResolver = new CamelCasePropertyNamesContractResolver() diff --git a/WorkspaceServer.Tests/Kernel/CSharpKernelTests.cs b/WorkspaceServer.Tests/Kernel/CSharpKernelTests.cs index 119dba937..9e3529061 100644 --- a/WorkspaceServer.Tests/Kernel/CSharpKernelTests.cs +++ b/WorkspaceServer.Tests/Kernel/CSharpKernelTests.cs @@ -203,6 +203,28 @@ public async Task it_produces_values_when_executing_Console_output() new DisplayedValueProduced("value three", kernelCommand, new[] { new FormattedValue("text/plain", "value three"), })); } + [Fact] + public async Task it_can_cancel_execution() + { + var kernel = CreateKernel(); + + var submitCodeCommand = new SubmitCode(@"System.Threading.Thread.Sleep(90000000);"); + var codeSubmission = kernel.SendAsync(submitCodeCommand); + var interruptionCommand = new CancelCurrentCommand(); + await kernel.SendAsync(interruptionCommand); + await codeSubmission; + + KernelEvents + .ValuesOnly() + .Single(e => e is CurrentCommandCancelled); + + KernelEvents + .ValuesOnly() + .OfType() + .Should() + .BeEquivalentTo(new CommandFailed(null, interruptionCommand, "Command cancelled")); + } + [Fact] public async Task it_produces_a_final_value_if_the_code_expression_evaluates() { diff --git a/WorkspaceServer/Kernel/CSharpKernel.cs b/WorkspaceServer/Kernel/CSharpKernel.cs index f21f7a56e..a2657dca1 100644 --- a/WorkspaceServer/Kernel/CSharpKernel.cs +++ b/WorkspaceServer/Kernel/CSharpKernel.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Reflection; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Completion; @@ -33,16 +34,20 @@ public class CSharpKernel : KernelBase private static readonly MethodInfo _hasReturnValueMethod = typeof(Script) .GetMethod("HasReturnValue", BindingFlags.Instance | BindingFlags.NonPublic); - protected CSharpParseOptions ParseOptions = new CSharpParseOptions(LanguageVersion.Default, kind: SourceCodeKind.Script); + protected CSharpParseOptions ParseOptions = + new CSharpParseOptions(LanguageVersion.Default, kind: SourceCodeKind.Script); private ScriptState _scriptState; protected ScriptOptions ScriptOptions; private ImmutableArray _metadataReferences; private WorkspaceFixture _fixture; + private CancellationTokenSource _cancellationSource; + private readonly object _cancellationSourceLock = new object(); public CSharpKernel() { + _cancellationSource = new CancellationTokenSource(); _metadataReferences = ImmutableArray.Empty; SetupScriptOptions(); Name = KernelName; @@ -66,6 +71,8 @@ private void SetupScriptOptions() typeof(CSharpKernel).Assembly, typeof(PocketView).Assembly, typeof(XPlot.Plotly.PlotlyChart).Assembly); + + } protected override async Task HandleAsync( @@ -87,6 +94,13 @@ protected override async Task HandleAsync( await HandleRequestCompletion(requestCompletion, invocationContext, _scriptState); }; break; + + case CancelCurrentCommand interruptExecution: + interruptExecution.Handler = async invocationContext => + { + await HandleCancelCurrentCommand(interruptExecution, invocationContext); + }; + break; } } @@ -95,12 +109,31 @@ public Task IsCompleteSubmissionAsync(string code) var syntaxTree = SyntaxFactory.ParseSyntaxTree(code, ParseOptions); return Task.FromResult(SyntaxFactory.IsCompleteSubmission(syntaxTree)); } - - private async Task HandleSubmitCode( - SubmitCode submitCode, + private async Task HandleCancelCurrentCommand( + CancelCurrentCommand cancelCurrentCommand, KernelInvocationContext context) { + var reply = new CurrentCommandCancelled(cancelCurrentCommand); + lock (_cancellationSourceLock) + { + _cancellationSource.Cancel(); + _cancellationSource = new CancellationTokenSource(); + } + + context.Publish(reply); + context.Complete(); + } + + private async Task HandleSubmitCode( + SubmitCode submitCode, + KernelInvocationContext context) + { + CancellationTokenSource cancellationSource; + lock (_cancellationSourceLock) + { + cancellationSource = _cancellationSource; + } var codeSubmissionReceived = new CodeSubmissionReceived( submitCode.Code, submitCode); @@ -109,7 +142,7 @@ private async Task HandleSubmitCode( var code = submitCode.Code; var isComplete = await IsCompleteSubmissionAsync(submitCode.Code); - if(isComplete) + if (isComplete) { context.Publish(new CompleteCodeSubmissionReceived(submitCode)); } @@ -124,64 +157,79 @@ private async Task HandleSubmitCode( } Exception exception = null; - using var console = await ConsoleOutput.Capture(); using var _ = console.SubscribeToStandardOutput(std => PublishOutput(std, context, submitCode)); + var scriptState = _scriptState; - try + if (!cancellationSource.IsCancellationRequested) { - if (_scriptState == null) + try { - _scriptState = await CSharpScript.RunAsync( - code, - ScriptOptions); + if (scriptState == null) + { + scriptState = await CSharpScript.RunAsync( + code, + ScriptOptions, + cancellationToken: cancellationSource.Token) + .UntilCancelled(cancellationSource.Token); + } + else + { + scriptState = await _scriptState.ContinueWithAsync( + code, + ScriptOptions, + e => + { + exception = e; + return true; + }, + cancellationToken: cancellationSource.Token) + .UntilCancelled(cancellationSource.Token); + } } - else + catch (Exception e) { - _scriptState = await _scriptState.ContinueWithAsync( - code, - ScriptOptions, - e => - { - exception = e; - return true; - }); + exception = e; } } - catch (Exception e) - { - exception = e; - } - if (exception != null) + if (!cancellationSource.IsCancellationRequested) { - string message = null; + _scriptState = scriptState; + if (exception != null) + { + string message = null; + + if (exception is CompilationErrorException compilationError) + { + message = + string.Join(Environment.NewLine, + compilationError.Diagnostics.Select(d => d.ToString())); + } - if (exception is CompilationErrorException compilationError) + context.Publish(new CommandFailed(exception, submitCode, message)); + } + else { - message = - string.Join(Environment.NewLine, - compilationError.Diagnostics.Select(d => d.ToString())); + if (_scriptState != null && HasReturnValue) + { + var formattedValues = FormattedValue.FromObject(_scriptState.ReturnValue); + context.Publish( + new ReturnValueProduced( + _scriptState.ReturnValue, + submitCode, + formattedValues)); + } + + context.Publish(new CodeSubmissionEvaluated(submitCode)); } - - context.Publish(new CommandFailed(exception, submitCode, message)); - context.Complete(); } else { - if (HasReturnValue) - { - var formattedValues = FormattedValue.FromObject(_scriptState.ReturnValue); - context.Publish( - new ReturnValueProduced( - _scriptState.ReturnValue, - submitCode, - formattedValues)); - } - - context.Publish(new CodeSubmissionEvaluated(submitCode)); - context.Complete(); + context.Publish(new CommandFailed(null, submitCode, "Command cancelled")); } + + context.Complete(); } private void PublishOutput( @@ -217,7 +265,7 @@ private async Task HandleRequestCompletion( context.Publish(new CompletionRequestCompleted(completionList, requestCompletion)); } - public void AddMetatadaReferences(IEnumerable references) + public void AddMetadataReferences(IEnumerable references) { _metadataReferences.AddRange(references); ScriptOptions = ScriptOptions.AddReferences(references); diff --git a/WorkspaceServer/Kernel/CSharpKernelExtensions.cs b/WorkspaceServer/Kernel/CSharpKernelExtensions.cs index 2090fbaeb..72bc8a7e8 100644 --- a/WorkspaceServer/Kernel/CSharpKernelExtensions.cs +++ b/WorkspaceServer/Kernel/CSharpKernelExtensions.cs @@ -80,7 +80,7 @@ public static CSharpKernel UseNugetDirective(this CSharpKernel kernel, INativeAs } } - kernel.AddMetatadaReferences(refs); + kernel.AddMetadataReferences(refs); } context.Publish(new NuGetPackageAdded(package)); diff --git a/WorkspaceServer/Kernel/ScriptExecutionExtensions.cs b/WorkspaceServer/Kernel/ScriptExecutionExtensions.cs new file mode 100644 index 000000000..3e8c45a2f --- /dev/null +++ b/WorkspaceServer/Kernel/ScriptExecutionExtensions.cs @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.Scripting; + +namespace WorkspaceServer.Kernel +{ + internal static class ScriptExecutionExtensions + { + public static async Task> UntilCancelled( + this Task> source, + CancellationToken cancellationToken) + { + var completed = await Task.WhenAny( + source, + Task.Run(async () => + { + while (!cancellationToken.IsCancellationRequested) + { + await Task.Delay(100, cancellationToken); + } + return (ScriptState) null; + }, cancellationToken)); + + + return completed.Result; + + } + } +} \ No newline at end of file